From 7f0b82af2705fae661ba5f672a8073c653948f14 Mon Sep 17 00:00:00 2001 From: Yakov Olkhovskiy <99031427+yakov-olkhovskiy@users.noreply.github.com> Date: Tue, 14 Apr 2026 20:20:08 +0000 Subject: [PATCH 1/4] Cherry-pick of https://github.com/ClickHouse/ClickHouse/pull/91170 with unresolved conflict markers (resolution in next commit) --- Original cherry-pick message follows: Merge pull request #91170 from ClickHouse/feat-arrowflight-impl Add Arrow Flight SQL support # Conflicts: # ci/jobs/scripts/check_style/check_cpp.sh # src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp # src/Server/ArrowFlightHandler.cpp --- ci/jobs/scripts/check_style/check_cpp.sh | 44 + contrib/arrow-cmake/flight.cmake | 41 +- .../grpc-cmake/protobuf_generate_grpc.cmake | 13 +- programs/server/Server.cpp | 4 +- src/CMakeLists.txt | 3 + src/Core/FormatFactorySettings.h | 6 + src/Core/SettingsChangesHistory.cpp | 22 + src/Formats/FormatFactory.cpp | 2 + src/Formats/FormatSettings.h | 2 + .../Formats/Impl/ArrowBlockOutputFormat.cpp | 1 + .../Formats/Impl/CHColumnToArrowColumn.cpp | 772 ++++++++++- .../Formats/Impl/CHColumnToArrowColumn.h | 25 +- .../Formats/Impl/ParquetBlockOutputFormat.cpp | 3 +- src/Server/ArrowFlight/ArrowFlightServer.cpp | 1228 +++++++++++++++++ .../ArrowFlightServer.h} | 12 +- src/Server/ArrowFlight/AuthMiddleware.cpp | 256 ++++ src/Server/ArrowFlight/AuthMiddleware.h | 96 ++ src/Server/ArrowFlight/CallsData.cpp | 567 ++++++++ src/Server/ArrowFlight/CallsData.h | 222 +++ src/Server/ArrowFlight/PollSession.cpp | 91 ++ src/Server/ArrowFlight/PollSession.h | 59 + src/Server/ArrowFlight/commandSelector.cpp | 705 ++++++++++ src/Server/ArrowFlight/commandSelector.h | 84 ++ src/Server/grpc_protos/CMakeLists.txt | 3 +- .../flight_sql_client.py | 486 +++++++ .../test_arrowflight_interface/test.py | 10 +- .../test_sql_server.py | 871 ++++++++++++ .../04070_arrow_complex_types.reference | 43 + .../0_stateless/04070_arrow_complex_types.sh | 202 +++ 29 files changed, 5770 
insertions(+), 103 deletions(-) create mode 100644 src/Server/ArrowFlight/ArrowFlightServer.cpp rename src/Server/{ArrowFlightHandler.h => ArrowFlight/ArrowFlightServer.h} (88%) create mode 100644 src/Server/ArrowFlight/AuthMiddleware.cpp create mode 100644 src/Server/ArrowFlight/AuthMiddleware.h create mode 100644 src/Server/ArrowFlight/CallsData.cpp create mode 100644 src/Server/ArrowFlight/CallsData.h create mode 100644 src/Server/ArrowFlight/PollSession.cpp create mode 100644 src/Server/ArrowFlight/PollSession.h create mode 100644 src/Server/ArrowFlight/commandSelector.cpp create mode 100644 src/Server/ArrowFlight/commandSelector.h create mode 100644 tests/integration/test_arrowflight_interface/flight_sql_client.py create mode 100644 tests/integration/test_arrowflight_interface/test_sql_server.py create mode 100644 tests/queries/0_stateless/04070_arrow_complex_types.reference create mode 100755 tests/queries/0_stateless/04070_arrow_complex_types.sh diff --git a/ci/jobs/scripts/check_style/check_cpp.sh b/ci/jobs/scripts/check_style/check_cpp.sh index 3b0000f80f16..621ba855f373 100755 --- a/ci/jobs/scripts/check_style/check_cpp.sh +++ b/ci/jobs/scripts/check_style/check_cpp.sh @@ -352,6 +352,50 @@ PATTERN="allow_"; DIFF=$(comm -3 <(grep -o "\b$PATTERN\w*\b" $ROOT_PATH/src/Core/Settings.cpp | sort -u) <(grep -o -h "\b$PATTERN\w*\b" $ROOT_PATH/src/Databases/enableAllExperimentalSettings.cpp $ROOT_PATH/ci/jobs/scripts/check_style/experimental_settings_ignore.txt | sort -u)); [ -n "$DIFF" ] && echo "$DIFF" && echo "^^ Detected 'allow_*' settings that might need to be included in src/Databases/enableAllExperimentalSettings.cpp" && echo "Alternatively, consider adding an exception to ci/jobs/scripts/check_style/experimental_settings_ignore.txt" +<<<<<<< HEAD +======= +# 12a: NDEBUG and cast checks on nobase_all +{ +# A small typo can lead to debug code in release builds, see https://github.com/ClickHouse/ClickHouse/pull/47647 +xargs < "$STYLE_TMPDIR/nobase_all" grep -l 
-F '#ifdef NDEBUG' | \ + xargs awk '/#ifdef NDEBUG/ { inside = 1; dirty = 1 } /#endif/ { if (inside && dirty) { print "File " FILENAME " has suspicious #ifdef NDEBUG, possibly confused with #ifndef NDEBUG" }; inside = 0 } /#else/ { dirty = 0 }' + +# If a user is doing dynamic or typeid cast with a pointer, and immediately dereferencing it, it is unsafe. +xargs < "$STYLE_TMPDIR/nobase_all" rg --line-number '(dynamic|typeid)_cast<[^>]+\*>\([^\(\)]+\)->' | grep . && echo "It's suspicious when you are doing a dynamic_cast or typeid_cast with a pointer and immediately dereferencing it. Use references instead of pointers or check a pointer to nullptr." +} > "$O.12a" 2>&1 & + +# 12b: Punctuation, std::regex, and Cyrillic checks on nobase_all +{ +# Check for bad punctuation: whitespace before comma. +xargs < "$STYLE_TMPDIR/nobase_all" rg --line-number '\w ,' | grep -v 'bad punctuation is ok here' && echo "^ There is bad punctuation: whitespace before comma. You should write it like this: 'Hello, world!'" + +# Check usage of std::regex which is too bloated and slow. +xargs < "$STYLE_TMPDIR/nobase_all" grep -F --line-number 'std::regex' | grep . && echo "^ Please use re2 instead of std::regex" + +# Cyrillic characters hiding inside Latin. +grep -v StorageSystemContributors.generated.cpp "$STYLE_TMPDIR/nobase_all" | \ + xargs rg --line-number '[a-zA-Z][а-яА-ЯёЁ]|[а-яА-ЯёЁ][a-zA-Z]' && echo "^ Cyrillic characters found in unexpected place." +} > "$O.12b" 2>&1 & + +# 13: Orphaned header files +{ +join -v1 <(grep '\.h$' "$STYLE_TMPDIR/nobase_all" | sed 's:.*/::' | sort -u) <(rg --no-filename -o '[\w-]+\.h' --glob '*.cpp' --glob '*.c' --glob '*.h' --glob '*.S' $ROOT_PATH/src $ROOT_PATH/programs $ROOT_PATH/utils $ROOT_PATH/tests/lexer | sort -u) | + grep . && echo '^ Found orphan header files.' +} > "$O.13" 2>&1 & + +# 14: Abbreviation checks and error message style +{ +# Wrong spelling of abbreviations, e.g. SQL is right, Sql is wrong. XMLHttpRequest is very wrong. 
+xargs < "$STYLE_TMPDIR/all_excluded" rg 'Sql|Html|Xml|Cpu|Tcp|Udp|Http|Db|Json|Yaml' | grep -v -E 'RabbitMQ|Azure|Aws|aws|Avro|IO/S3|ai::JsonValue|IcebergWrites|arrow::flight|SqlInfo|CommandGetSqlInfo|CommandGetDbSchemas|commandGetDbSchemas|ArrowFlightSql|TcpExtListenOverflows' && + echo "Abbreviations such as SQL, XML, HTTP, should be in all caps. For example, SQL is right, Sql is wrong. XMLHttpRequest is very wrong." + +xargs < "$STYLE_TMPDIR/all_excluded" grep -F -i 'ErrorCodes::LOGICAL_ERROR, "Logical error:' && + echo "If an exception has LOGICAL_ERROR code, there is no need to include the text 'Logical error' in the exception message, because then the phrase 'Logical error' will be printed twice." +} > "$O.14" 2>&1 & + +# 15: magic_enum and std::format +{ +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) # Don't allow the direct inclusion of magic_enum.hpp and instead point to base/EnumReflection.h find $ROOT_PATH/{src,base,programs,utils} -name '*.cpp' -or -name '*.h' | xargs grep -l "magic_enum.hpp" | grep -v EnumReflection.h | while read -r line; do diff --git a/contrib/arrow-cmake/flight.cmake b/contrib/arrow-cmake/flight.cmake index c028ab8de13c..b2b5d8a49426 100644 --- a/contrib/arrow-cmake/flight.cmake +++ b/contrib/arrow-cmake/flight.cmake @@ -8,27 +8,33 @@ endif() if(NOT ENABLE_GRPC) message(${RECONFIGURE_MESSAGE_LEVEL} "Can't use ArrowFlight without gRPC") + return() endif() set(GRPC_INCLUDE_DIR ${ClickHouse_SOURCE_DIR}/contrib/grpc/include) set(ARROW_FLIGHT_SRC_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src/arrow/flight) +set(ARROW_FLIGHT_SQL_SRC_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/cpp/src/arrow/flight/sql) set(ARROW_FLIGHT_PROTO_DIR ${ClickHouse_SOURCE_DIR}/contrib/arrow/format) set(ARROW_FLIGHT_GENERATED_SRC_DIR ${ARROW_GENERATED_SRC_DIR}/arrow/flight) +set(ARROW_FLIGHT_SQL_GENERATED_SRC_DIR ${ARROW_GENERATED_SRC_DIR}/arrow/flight/sql) + +set(PROTOBUF_IMPORT_DIRS ${ARROW_FLIGHT_PROTO_DIR} 
${ClickHouse_SOURCE_DIR}/contrib/google-protobuf/src) + +PROTOBUF_GENERATE_GRPC_CPP( + flight_sources + flight_headers + APPEND_PATH + PROTOC_OUT_DIR ${ARROW_FLIGHT_GENERATED_SRC_DIR} + ${ARROW_FLIGHT_PROTO_DIR}/Flight.proto +) -add_custom_command( - OUTPUT - "${ARROW_FLIGHT_GENERATED_SRC_DIR}/Flight.grpc.pb.cc" - "${ARROW_FLIGHT_GENERATED_SRC_DIR}/Flight.grpc.pb.h" - "${ARROW_FLIGHT_GENERATED_SRC_DIR}/Flight.pb.cc" - "${ARROW_FLIGHT_GENERATED_SRC_DIR}/Flight.pb.h" - COMMAND ${PROTOBUF_EXECUTABLE} - -I ${ARROW_FLIGHT_PROTO_DIR} - -I "${ClickHouse_SOURCE_DIR}/contrib/google-protobuf/src" - --cpp_out="${ARROW_FLIGHT_GENERATED_SRC_DIR}" - --grpc_out="${ARROW_FLIGHT_GENERATED_SRC_DIR}" - --plugin=protoc-gen-grpc="${GRPC_EXECUTABLE}" - "${ARROW_FLIGHT_PROTO_DIR}/Flight.proto" +PROTOBUF_GENERATE_GRPC_CPP( + flight_sql_sources + flight_sql_headers + APPEND_PATH + PROTOC_OUT_DIR ${ARROW_FLIGHT_SQL_GENERATED_SRC_DIR} + ${ARROW_FLIGHT_PROTO_DIR}/FlightSql.proto ) # NOTE: we do not compile the ${ARROW_FLIGHT_GENERATED_SRCS} directly, instead @@ -37,6 +43,7 @@ add_custom_command( # protobuf-internal.cc set(ARROW_FLIGHT_SRCS ${ARROW_FLIGHT_GENERATED_SRC_DIR}/Flight.pb.cc + ${ARROW_FLIGHT_SQL_GENERATED_SRC_DIR}/FlightSql.pb.cc ${ARROW_FLIGHT_SRC_DIR}/client.cc ${ARROW_FLIGHT_SRC_DIR}/client_cookie_middleware.cc ${ARROW_FLIGHT_SRC_DIR}/client_tracing_middleware.cc @@ -54,6 +61,12 @@ set(ARROW_FLIGHT_SRCS ${ARROW_FLIGHT_SRC_DIR}/transport/grpc/serialization_internal.cc ${ARROW_FLIGHT_SRC_DIR}/transport/grpc/util_internal.cc ${ARROW_FLIGHT_SRC_DIR}/types.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/client.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/column_metadata.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/protocol_internal.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/server_session_middleware.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/server.cc + ${ARROW_FLIGHT_SQL_SRC_DIR}/sql_info_internal.cc ) add_library(_arrow_flight ${ARROW_FLIGHT_SRCS}) @@ -62,4 +75,4 @@ add_library(ch_contrib::arrow_flight ALIAS _arrow_flight) 
add_dependencies(_arrow_flight _protoc grpc_cpp_plugin) target_link_libraries(_arrow_flight PUBLIC _arrow) target_link_libraries(_arrow_flight PRIVATE _protobuf grpc++) -target_include_directories(_arrow_flight PRIVATE ${ARROW_GENERATED_SRC_DIR}) +target_include_directories(_arrow_flight PUBLIC ${ARROW_GENERATED_SRC_DIR}) diff --git a/contrib/grpc-cmake/protobuf_generate_grpc.cmake b/contrib/grpc-cmake/protobuf_generate_grpc.cmake index 4b189301724f..da44d3fc8a80 100644 --- a/contrib/grpc-cmake/protobuf_generate_grpc.cmake +++ b/contrib/grpc-cmake/protobuf_generate_grpc.cmake @@ -9,6 +9,8 @@ protobuf_generate_grpc_cpp( Variable to define with autogenerated header files ``DESCRIPTORS`` Variable to define with autogenerated descriptor files, if requested. +``PROTOC_OUT_DIR`` + Output directory for generated sources ``EXPORT_MACRO`` is a macro which should expand to ``__declspec(dllexport)`` or ``__declspec(dllimport)`` depending on what is being compiled. @@ -31,7 +33,7 @@ function(PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS) set(NATIVE_protoc $) endif() - cmake_parse_arguments(protobuf_generate_grpc_cpp "" "EXPORT_MACRO;DESCRIPTORS" "" ${ARGN}) + cmake_parse_arguments(protobuf_generate_grpc_cpp "APPEND_PATH" "EXPORT_MACRO;DESCRIPTORS;PROTOC_OUT_DIR" "" ${ARGN}) set(_proto_files "${protobuf_generate_grpc_cpp_UNPARSED_ARGUMENTS}") if(NOT _proto_files) @@ -39,7 +41,7 @@ function(PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS) return() endif() - if(PROTOBUF_GENERATE_GRPC_CPP_APPEND_PATH) + if(protobuf_generate_grpc_cpp_APPEND_PATH) set(_append_arg APPEND_PATH) endif() @@ -47,6 +49,11 @@ function(PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS) set(_descriptors DESCRIPTORS) endif() + set(_protoc_out_dir_arg) + if(protobuf_generate_grpc_cpp_PROTOC_OUT_DIR) + set(_protoc_out_dir_arg PROTOC_OUT_DIR ${protobuf_generate_grpc_cpp_PROTOC_OUT_DIR}) + endif() + if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS) set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}") endif() @@ -56,7 +63,7 @@ 
function(PROTOBUF_GENERATE_GRPC_CPP SRCS HDRS) endif() set(_outvar) - protobuf_generate_grpc(${_append_arg} ${_descriptors} LANGUAGE cpp EXPORT_MACRO ${protobuf_generate_cpp_EXPORT_MACRO} OUT_VAR _outvar ${_import_arg} PROTOS ${_proto_files}) + protobuf_generate_grpc(${_append_arg} ${_descriptors} LANGUAGE cpp EXPORT_MACRO ${protobuf_generate_grpc_cpp_EXPORT_MACRO} OUT_VAR _outvar ${_import_arg} ${_protoc_out_dir_arg} PROTOS ${_proto_files}) set(${SRCS}) set(${HDRS}) diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index 774c6d9f479b..cc8c22f45249 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -123,7 +123,7 @@ #include #include #include -#include +#include #include #include @@ -3473,7 +3473,7 @@ void Server::createServers( listen_host, port_name, "Arrow Flight compatibility protocol: " + address.toString(), - std::unique_ptr(new ArrowFlightHandler(*this, makeSocketAddress(listen_host, port, &logger()))), + std::unique_ptr(new ArrowFlightServer(*this, makeSocketAddress(listen_host, port, &logger()))), true); }); } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 10e3d049ac52..096f81f4bb75 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -331,6 +331,9 @@ endif() if (TARGET ch_contrib::ssh) add_object_library(clickhouse_server_ssh Server/SSH) endif() +if (TARGET ch_contrib::arrow_flight) + add_object_library(clickhouse_server_arrowflight Server/ArrowFlight) +endif() add_object_library(clickhouse_server_embedded_client Server/ClientEmbedded) add_object_library(clickhouse_formats Formats) add_object_library(clickhouse_processors Processors) diff --git a/src/Core/FormatFactorySettings.h b/src/Core/FormatFactorySettings.h index 499c060d885a..c51cf883b2f2 100644 --- a/src/Core/FormatFactorySettings.h +++ b/src/Core/FormatFactorySettings.h @@ -1207,6 +1207,9 @@ Write enum using parquet physical type: BYTE_ARRAY and logical type: ENUM )", 0) \ DECLARE(Bool, output_format_parquet_write_checksums, true, R"( 
Put crc32 checksums in parquet page headers. +)", 0) \ + DECLARE(Bool, output_format_parquet_unsupported_types_as_binary, false, R"( +Output types having no conversion as raw binary data. If false - such types would raise UNKNOWN_TYPE exception. )", 0) \ DECLARE(String, output_format_avro_codec, "", R"( Compression codec used for output. Possible values: 'null', 'deflate', 'snappy', 'zstd'. @@ -1407,6 +1410,9 @@ Compression method for Arrow output format. Supported codecs: lz4_frame, zstd, n )", 0) \ DECLARE(Bool, output_format_arrow_date_as_uint16, false, R"( Write Date values as plain 16-bit numbers (read back as UInt16), instead of converting to a 32-bit Arrow DATE32 type (read back as Date32). +)", 0) \ + DECLARE(Bool, output_format_arrow_unsupported_types_as_binary, true, R"( +Output types having no conversion as raw binary data. If false - such types would raise UNKNOWN_TYPE exception. )", 0) \ \ DECLARE(Bool, output_format_orc_string_as_string, true, R"( diff --git a/src/Core/SettingsChangesHistory.cpp b/src/Core/SettingsChangesHistory.cpp index d785ceaf7f68..2fb63c9ef2c3 100644 --- a/src/Core/SettingsChangesHistory.cpp +++ b/src/Core/SettingsChangesHistory.cpp @@ -42,6 +42,28 @@ const VersionToSettingsChangesMap & getSettingsChangesHistory() addSettingsChanges(settings_changes_history, "26.3.1.20001.altinityantalya", { {"object_storage_cluster_join_mode", "allow", "allow", "New setting"}, + {"output_format_arrow_unsupported_types_as_binary", false, true, "New setting to convert unsupported CH types to arrow binary instead of UNKNOWN_TYPE exception."}, + {"output_format_parquet_unsupported_types_as_binary", false, false, "New setting to convert unsupported CH types to parquet (arrow) binary instead of UNKNOWN_TYPE exception."}, + {"asterisk_include_virtual_columns", false, false, "New setting"}, + {"max_wkb_geometry_elements", 1'000'000, 1'000'000, "New setting to limit element counts in WKB geometry parsing, preventing excessive memory allocation on 
malformed data."}, + {"max_rand_distribution_trials", 1'000'000'000, 1'000'000'000, "New setting to limit trial counts in random distribution functions, preventing hangs with extreme inputs."}, + {"max_rand_distribution_parameter", 1e6, 1e6, "New setting to limit shape parameters in random distribution functions, preventing hangs with extreme inputs."}, + {"optimize_truncate_order_by_after_group_by_keys", false, true, "Remove trailing ORDER BY elements once all GROUP BY keys are covered in the ORDER BY prefix."}, + {"use_statistics_for_part_pruning", false, true, "New setting to use statistics for part pruning during query execution."}, + {"distributed_index_analysis_only_on_coordinator", false, false, "New setting."}, + {"query_plan_optimize_join_order_randomize", 0, 0, "New setting to randomize join order statistics for testing."}, + {"enable_materialized_cte", false, false, "New setting"}, + {"use_strict_insert_block_limits", false, false, "New setting to use strict min and max insert bounds on inserts. 
When min < max, max limits take precedence."}, + {"finalize_projection_parts_synchronously", false, false, "New setting to finalize projection parts synchronously during INSERT to reduce peak memory usage."}, + {"read_in_order_use_virtual_row_per_block", false, false, "Emit virtual row after each block during read-in-order to allow more frequent source reprioritization in MergingSortedTransform."}, + {"distributed_plan_prefer_replicas_over_workers", false, false, "New setting to serialize distributed plan for replicas"}, + {"use_text_index_like_evaluation_by_dictionary_scan", true, true, "New setting"}, + {"text_index_like_min_pattern_length", 4, 4, "New setting"}, + {"text_index_like_max_postings_to_read", 50, 50, "New setting"}, + {"analyzer_inline_views", false, false, "New setting"}, + {"highlight_max_matches_per_row", 10000, 10000, "New setting to limit the number of highlight matches per row to protect against excessive memory usage."}, + {"materialize_statistics_on_insert", true, false, "Disable building statistics on INSERT by default, rely on merges instead"}, + {"enable_join_transitive_predicates", false, false, "New setting to infer transitive equi-join predicates for join order optimization."}, }); addSettingsChanges(settings_changes_history, "26.3", { diff --git a/src/Formats/FormatFactory.cpp b/src/Formats/FormatFactory.cpp index 807bf03ce4f7..2da99a29757f 100644 --- a/src/Formats/FormatFactory.cpp +++ b/src/Formats/FormatFactory.cpp @@ -221,6 +221,7 @@ FormatSettings getFormatSettings(const ContextPtr & context, const Settings & se format_settings.parquet.max_dictionary_size = settings[Setting::output_format_parquet_max_dictionary_size]; format_settings.parquet.output_enum_as_byte_array = settings[Setting::output_format_parquet_enum_as_byte_array]; format_settings.parquet.write_checksums = settings[Setting::output_format_parquet_write_checksums]; + format_settings.parquet.output_unsupported_types_as_binary = 
settings[Setting::output_format_parquet_unsupported_types_as_binary]; format_settings.parquet.max_block_size = settings[Setting::input_format_parquet_max_block_size]; format_settings.parquet.prefer_block_bytes = settings[Setting::input_format_parquet_prefer_block_bytes]; format_settings.parquet.output_compression_method = settings[Setting::output_format_parquet_compression_method]; @@ -313,6 +314,7 @@ FormatSettings getFormatSettings(const ContextPtr & context, const Settings & se format_settings.arrow.output_fixed_string_as_fixed_byte_array = settings[Setting::output_format_arrow_fixed_string_as_fixed_byte_array]; format_settings.arrow.output_compression_method = settings[Setting::output_format_arrow_compression_method]; format_settings.arrow.output_date_as_uint16 = settings[Setting::output_format_arrow_date_as_uint16]; + format_settings.arrow.output_unsupported_types_as_binary = settings[Setting::output_format_arrow_unsupported_types_as_binary]; format_settings.orc.allow_missing_columns = settings[Setting::input_format_orc_allow_missing_columns]; format_settings.orc.row_batch_size = settings[Setting::input_format_orc_row_batch_size]; format_settings.orc.skip_columns_with_unsupported_types_in_schema_inference = settings[Setting::input_format_orc_skip_columns_with_unsupported_types_in_schema_inference]; diff --git a/src/Formats/FormatSettings.h b/src/Formats/FormatSettings.h index 8cc2792c21f8..ab17c9e77d70 100644 --- a/src/Formats/FormatSettings.h +++ b/src/Formats/FormatSettings.h @@ -166,6 +166,7 @@ struct FormatSettings bool output_fixed_string_as_fixed_byte_array = true; ArrowCompression output_compression_method = ArrowCompression::NONE; bool output_date_as_uint16 = false; + bool output_unsupported_types_as_binary = true; } arrow{}; struct @@ -348,6 +349,7 @@ struct FormatSettings bool allow_geoparquet_parser = true; bool write_geometadata = true; size_t max_dictionary_size = 1024 * 1024; + bool output_unsupported_types_as_binary = false; } parquet{}; struct 
Pretty diff --git a/src/Processors/Formats/Impl/ArrowBlockOutputFormat.cpp b/src/Processors/Formats/Impl/ArrowBlockOutputFormat.cpp index 1a71a4a32447..cc56309ac056 100644 --- a/src/Processors/Formats/Impl/ArrowBlockOutputFormat.cpp +++ b/src/Processors/Formats/Impl/ArrowBlockOutputFormat.cpp @@ -64,6 +64,7 @@ void ArrowBlockOutputFormat::consume(Chunk chunk) .use_signed_indexes_for_dictionary = format_settings.arrow.use_signed_indexes_for_dictionary, .use_64_bit_indexes_for_dictionary = format_settings.arrow.use_64_bit_indexes_for_dictionary, .output_date_as_uint16 = format_settings.arrow.output_date_as_uint16, + .output_unsupported_types_as_binary = format_settings.arrow.output_unsupported_types_as_binary, }); } diff --git a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp index 5acfa96721c6..956699c4c0e2 100644 --- a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp +++ b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp @@ -3,6 +3,7 @@ #if USE_ARROW || USE_PARQUET #include +#include #include #include #include @@ -10,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -20,6 +22,13 @@ #include #include #include +<<<<<<< HEAD +======= +#include +#include +#include +#include +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) #include #include #include @@ -28,6 +37,21 @@ #include #include #include +<<<<<<< HEAD +======= +#include +#include +#include +#include +#include +#include +#include +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) + +#include +#include +#include +#include #define FOR_INTERNAL_NUMERIC_TYPES(M) \ M(Int8, arrow::Int8Builder) \ @@ -62,10 +86,10 @@ namespace DB namespace ErrorCodes { extern const int UNKNOWN_EXCEPTION; - extern const int UNKNOWN_TYPE; extern const int LOGICAL_ERROR; extern const int DECIMAL_OVERFLOW; extern const int ILLEGAL_COLUMN; + extern const int UNKNOWN_TYPE; 
} static const std::initializer_list>> internal_type_to_arrow_type = @@ -103,6 +127,55 @@ namespace DB throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Error with a {} column \"{}\": {}.", format_name, column_name, status.ToString()); } + template + static ResultType checkResult(arrow::Result && result, const String & column_name, const String & format_name) + { + checkStatus(result.status(), column_name, format_name); + return std::move(result).ValueUnsafe(); + } + + static std::shared_ptr nullBytemapToArrowBitmap( + const PaddedPODArray * null_bytemap, + const String & column_name, + const String & format_name, + size_t start, + size_t end) + { + if (!null_bytemap) + return nullptr; + + int64_t length = static_cast(end - start); + auto bitmap = checkResult(arrow::AllocateEmptyBitmap(length), column_name, format_name); + auto * data = bitmap->mutable_data(); + for (size_t i = 0; i < static_cast(length); ++i) + { + if (!(*null_bytemap)[start + i]) + arrow::bit_util::SetBit(data, static_cast(i)); + } + return bitmap; + } + + static void fillArrowArrayWithRawColumnData( + ColumnPtr write_column, + const PaddedPODArray * null_bytemap, + const String & format_name, + arrow::ArrayBuilder* array_builder, + size_t start, + size_t end) + { + arrow::BinaryBuilder & builder = assert_cast(*array_builder); + arrow::Status status; + + for (size_t value_i = start; value_i < end; ++value_i) + { + if (null_bytemap && (*null_bytemap)[value_i]) + status = builder.AppendNull(); + else + status = builder.Append(write_column->getDataAt(value_i)); + checkStatus(status, write_column->getName(), format_name); + } + } + /// Invert values since Arrow interprets 1 as a non-null value, while CH as a null static PaddedPODArray revertNullByteMap(const PaddedPODArray * null_bytemap, size_t start, size_t end) { @@ -231,9 +304,57 @@ namespace DB } } +<<<<<<< HEAD static void fillArrowArray( +======= + static void fillArrowArrayWithUUIDColumnData( + const ColumnPtr & column, + const PaddedPODArray * 
null_bytemap, + const String & format_name, + arrow::ArrayBuilder * array_builder, + size_t start, + size_t end) + { + const auto * col_uuid = assert_cast *>(column.get()); + + if (array_builder->type()->id() != arrow::Type::FIXED_SIZE_BINARY) + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Cannot fill arrow array with {} data for format {}", column->getName(), format_name); + + auto * fixed_builder = assert_cast(array_builder); + const auto & uuid_data = col_uuid->getData(); + + for (size_t i = start; i < end; ++i) + { + if (null_bytemap && (*null_bytemap)[i]) + { + arrow::Status status = fixed_builder->AppendNull(); + checkStatus(status, column->getName(), format_name); + continue; + } + + UUID res = uuid_data[i]; + auto * bytes = reinterpret_cast(&res); + + if constexpr (std::endian::native == std::endian::little) + { + std::reverse(bytes, bytes + 8); + std::reverse(bytes + 8, bytes + 16); + } + else + { + std::swap_ranges(bytes, bytes + 8, bytes + 8); + } + + arrow::Status status = fixed_builder->Append(reinterpret_cast(&res)); + checkStatus(status, column->getName(), format_name); + } + } + + static std::shared_ptr fillArrowArray( +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) const String & column_name, - ColumnPtr & column, + ColumnPtr column, const DataTypePtr & column_type, const PaddedPODArray * null_bytemap, arrow::ArrayBuilder * array_builder, @@ -243,44 +364,166 @@ namespace DB const CHColumnToArrowColumn::Settings & settings, std::unordered_map & dictionary_values); - template - static void fillArrowArrayWithArrayColumnData( - const String & column_name, - ColumnPtr & column, - const DataTypePtr & column_type, - const PaddedPODArray *, - arrow::ArrayBuilder * array_builder, - String format_name, + + static std::shared_ptr getArrowType( + DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * 
out_is_column_nullable, bool for_builder = false); + + + static std::shared_ptr buildArrowDenseUnionArrayWithVariantColumnData( + const ColumnVariant & column, + const DataTypeVariant & column_type, + const PaddedPODArray * null_bytemap, + const String & format_name, size_t start, size_t end, const CHColumnToArrowColumn::Settings & settings, std::unordered_map & dictionary_values) { - const auto * column_array = assert_cast(column.get()); - ColumnPtr nested_column = column_array->getDataPtr(); - DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); - const auto & offsets = column_array->getOffsets(); + size_t size = end - start; + const auto & column_offsets = column.getOffsets(); + const auto & discriminators = column.getLocalDiscriminators(); + arrow::Int8Builder type_ids_builder; + + const auto num_variants = column.getNumVariants(); + if (num_variants > static_cast(std::numeric_limits::max())) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Cannot convert Variant with {} nested types to {} Arrow DenseUnion: maximum supported is {} ", + num_variants, + format_name, + static_cast(std::numeric_limits::max())); + + std::vector starts(num_variants); + std::vector ends(num_variants); + arrow::Status status; + /// Here we are doing slicing - there is no clear specification on ColumnVariant having + /// offsets being monotonic and contiguous (though from current code it seems they are), + /// Arrow DenseUnion explicitly requires monotonicity, so we are going to tolerate non-contiguous + /// offsets, but raise an exception for violation of monotonicity. 
+ for (size_t idx = start; idx < discriminators.size() && idx < end; ++idx) + { + const auto & discriminator = discriminators[idx]; + if (discriminator != ColumnVariant::NULL_DISCRIMINATOR) + { + auto global_discr = column.globalDiscriminatorByLocal(discriminator); + if (ends[global_discr] == 0) + starts[global_discr] = column_offsets[idx]; + else if (column_offsets[idx] < ends[global_discr]) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Cannot convert Variant to {} Arrow DenseUnion: " + "variant offsets are not monotonic for discriminator {}", + format_name, std::to_string(global_discr)); + ends[global_discr] = column_offsets[idx] + 1; + } + + if (discriminator == ColumnVariant::NULL_DISCRIMINATOR || (null_bytemap && (*null_bytemap)[idx])) + status = type_ids_builder.Append(static_cast(num_variants)); + else + status = type_ids_builder.Append(static_cast(column.globalDiscriminatorByLocal(discriminator))); + + checkStatus(status, "type_ids", format_name); + } + + std::shared_ptr type_ids_array; + status = type_ids_builder.Finish(&type_ids_array); + checkStatus(status, "type_ids", format_name); - Builder & builder = assert_cast(*array_builder); - arrow::ArrayBuilder * value_builder = builder.value_builder(); - arrow::Status components_status; - for (size_t array_idx = start; array_idx < end; ++array_idx) + arrow::ArrayVector children; + for (size_t i = 0; i < column.getNumVariants(); ++i) { - /// Start new array. - components_status = builder.Append(); - checkStatus(components_status, nested_column->getName(), format_name); - - /// Pass null null_map, because fillArrowArray will decide whether nested_type is nullable, if nullable, it will create a new null_map from nested_column - /// Note that it is only needed by gluten(https://github.com/oap-project/gluten), because array type in gluten is by default nullable. 
- /// And it does not influence the original ClickHouse logic, because null_map passed to fillArrowArrayWithArrayColumnData is always nullptr for ClickHouse doesn't allow nullable complex types including array type. - fillArrowArray(column_name, nested_column, nested_type, nullptr, value_builder, format_name, offsets[array_idx - 1], offsets[array_idx], settings, dictionary_values); + const auto & variant = column.getVariantPtrByGlobalDiscriminator(i); + + bool is_column_nullable = false; + auto arrow_type = getArrowType( + column_type.getVariant(i), + variant, + variant->getName(), + format_name, + settings, + &is_column_nullable); + + std::unique_ptr variant_array_builder; + status = MakeBuilder(arrow::default_memory_pool(), arrow_type, &variant_array_builder); + checkStatus(status, variant->getName(), format_name); + + if (ends[i] == 0) + { + auto empty_array = checkResult(arrow::MakeArrayOfNull(arrow_type, 0), variant->getName(), format_name); + children.push_back(empty_array); + } + else + { + std::shared_ptr variant_arrow_array = fillArrowArray( + variant->getName(), + variant, + column_type.getVariant(i), + nullptr, + variant_array_builder.get(), + format_name, + starts[i], + ends[i], + settings, + dictionary_values); + + children.push_back(variant_arrow_array); + } } + children.push_back(std::make_shared(1)); + + arrow::Int32Builder offsets_builder; + + /// column_offsets should be sanitized because NULL_DISCRIMINATOR positions in ColumnVariant + /// makes offsets at these positions irrelevant (and they can have unspecified values), + /// but for arrow dense union they are pointing to an actual NULL array + auto to_arrow_offset = [&](const auto & tuple) -> int32_t + { + const auto & discriminator = boost::get<0>(tuple); + const auto & column_offset = boost::get<1>(tuple); + + if constexpr (std::tuple_size_v> == 3) + if (static_cast(boost::get<2>(tuple))) + return 0; + if (discriminator == ColumnVariant::NULL_DISCRIMINATOR) + return 0; + + const auto offset = 
column_offset - starts[column.globalDiscriminatorByLocal(discriminator)]; + if (offset > static_cast(std::numeric_limits::max())) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot build Arrow DenseUnion: offset {} is out of Int32 range", offset); + return static_cast(offset); + }; + + auto append_offsets = [&](Ts&&... args) + { + auto begin_it = boost::make_transform_iterator( + boost::make_zip_iterator(boost::make_tuple((args->begin() + start)...)), + to_arrow_offset + ); + auto end_it = boost::make_transform_iterator( + boost::make_zip_iterator(boost::make_tuple((args->begin() + start + size)...)), + to_arrow_offset + ); + return offsets_builder.AppendValues(begin_it, end_it); + }; + + if (null_bytemap) + status = append_offsets(&discriminators, &column_offsets, null_bytemap); + else + status = append_offsets(&discriminators, &column_offsets); + + checkStatus(status, "offsets", format_name); + std::shared_ptr offsets_array; + status = offsets_builder.Finish(&offsets_array); + checkStatus(status, "offsets", format_name); + + return checkResult(arrow::DenseUnionArray::Make(*type_ids_array, *offsets_array, children), "type_ids", format_name); } - static void fillArrowArrayWithTupleColumnData( + + static std::shared_ptr buildArrowStructArrayWithTupleColumnData( const String & column_name, - ColumnPtr & column, + const ColumnPtr & column, const DataTypePtr & column_type, const PaddedPODArray * null_bytemap, arrow::ArrayBuilder * array_builder, @@ -297,51 +540,84 @@ namespace DB arrow::StructBuilder & builder = assert_cast(*array_builder); + if (column_tuple->tupleSize() == 0) + { + for (size_t i = start; i != end; ++i) + checkStatus(builder.Append(), column->getName(), format_name); + return checkResult(builder.Finish(), column_name, format_name); + } + + arrow::ArrayVector children; + for (size_t i = 0; i != column_tuple->tupleSize(); ++i) { ColumnPtr nested_column = column_tuple->getColumnPtr(i); - fillArrowArray( - column_name + "." 
+ nested_names[i], + auto name = column_name + "." + nested_names[i]; + std::shared_ptr nested_arrow_array = fillArrowArray( + name, nested_column, nested_types[i], null_bytemap, builder.field_builder(static_cast(i)), format_name, start, end, settings, dictionary_values); - } - for (size_t i = start; i != end; ++i) - { - auto status = builder.Append(); - checkStatus(status, column->getName(), format_name); + children.push_back(nested_arrow_array); } + + auto null_bitmap = nullBytemapToArrowBitmap(null_bytemap, column_name, format_name, start, end); + return checkResult(arrow::StructArray::Make(children, builder.type()->fields(), null_bitmap), column_name, format_name); } - template - static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) + template + requires (std::integral && std::integral) + static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) { - const PaddedPODArray & data = assert_cast *>(column.get())->getData(); - PaddedPODArray result; + const PaddedPODArray & data = assert_cast *>(column.get())->getData(); + PaddedPODArray result; result.reserve(end - start); + + auto checked_cast = [](From value) -> To + { + constexpr bool always_safe = + // same signedness, destination has at least as many value bits + (std::numeric_limits::is_signed == std::numeric_limits::is_signed + && std::numeric_limits::digits >= std::numeric_limits::digits) + // unsigned -> signed is safe only if destination has strictly more value bits + || (!std::numeric_limits::is_signed + && std::numeric_limits::is_signed + && std::numeric_limits::digits > std::numeric_limits::digits); + + if constexpr (always_safe) + return static_cast(value); + + To converted{}; + if (!accurate::convertNumeric(value, converted)) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot convert index {} to target type without overflow", std::to_string(value)); + return converted; + }; + if (shift) - std::transform(data.begin() + start, 
data.begin() + end, std::back_inserter(result), [](T value) { return Int64(value) - 1; }); + std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), [&](From value) { return checked_cast(value) - 1; }); else - std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), [](T value) { return Int64(value); }); + std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), checked_cast); return result; } - static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) + template + requires std::integral + static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) { switch (column->getDataType()) { case TypeIndex::UInt8: - return extractIndexes(column, start, end, shift); + return extractIndexes(column, start, end, shift); case TypeIndex::UInt16: - return extractIndexes(column, start, end, shift); + return extractIndexes(column, start, end, shift); case TypeIndex::UInt32: - return extractIndexes(column, start, end, shift); + return extractIndexes(column, start, end, shift); case TypeIndex::UInt64: - return extractIndexes(column, start, end, shift); + return extractIndexes(column, start, end, shift); default: throw Exception(ErrorCodes::LOGICAL_ERROR, "Indexes column must be ColumnUInt, got {}.", column->getName()); } @@ -399,7 +675,7 @@ namespace DB /// We can use Int32/UInt32/Int64/UInt64 type for indexes. 
const auto * indexes_int32_type = typeid_cast(dict_indexes_arrow_type.get()); const auto * indexes_uint32_type = typeid_cast(dict_indexes_arrow_type.get()); - const auto * indexes_int64_type = typeid_cast(dict_indexes_arrow_type.get()); + const auto * indexes_int64_type = typeid_cast(dict_indexes_arrow_type.get()); if ((indexes_int32_type && dict_size > INT32_MAX) || (indexes_uint32_type && dict_size > UINT32_MAX) || (indexes_int64_type && dict_size > INT64_MAX)) throw Exception( ErrorCodes::ILLEGAL_COLUMN, @@ -407,10 +683,80 @@ namespace DB " resulting dictionary size exceeds the max value of index type {}", dict_indexes_arrow_type->name()); } + static std::shared_ptr buildArrowListArrayWithArrayColumnData( + const String & column_name, + const ColumnPtr & column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + arrow::ArrayBuilder * array_builder, + String format_name, + size_t start, + size_t end, + const CHColumnToArrowColumn::Settings & settings, + std::unordered_map & dictionary_values) + { + const auto * column_array = assert_cast(column.get()); + const auto * type_array = assert_cast(column_type.get()); + + const auto column_offsets = assert_cast(column_array->getOffsetsColumn()).getPtr(); + size_t offsets_start = start > 0 ? start - 1 : 0; + size_t offsets_view_start = start > 0 ? 1 : 0; + auto offsets = extractIndexes(column_offsets, offsets_start, end, false); + size_t values_start = start == 0 ? 0 : offsets[0]; + size_t values_end = offsets.empty() ? 
values_start : offsets.back(); + + arrow::ListBuilder & builder = assert_cast(*array_builder); + + auto data_array = fillArrowArray(column_name, column_array->getDataPtr(), type_array->getNestedType(), nullptr, builder.value_builder(), format_name, values_start, values_end, settings, dictionary_values); + + arrow::Status status; + arrow::Int32Builder offsets_builder; + status = offsets_builder.Append(0); + checkStatus(status, column_name, format_name); + for (size_t i = offsets_view_start; i < offsets.size(); ++i) + { + status = offsets_builder.Append(static_cast(offsets[i] - values_start)); + checkStatus(status, column_name, format_name); + } + + std::shared_ptr offsets_array; + status = offsets_builder.Finish(&offsets_array); + checkStatus(status, column_name, format_name); + + auto null_bitmap = nullBytemapToArrowBitmap(null_bytemap, column_name, format_name, start, end); + return checkResult(arrow::ListArray::FromArrays(*offsets_array, *data_array, arrow::default_memory_pool(), null_bitmap), column_name, format_name); + } + + static std::shared_ptr buildArrowMapArrayWithMapColumnData( + const String & column_name, + const ColumnPtr & column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + arrow::ArrayBuilder * array_builder, + String format_name, + size_t start, + size_t end, + const CHColumnToArrowColumn::Settings & settings, + std::unordered_map & dictionary_values) + { + const auto * column_map = assert_cast(column.get()); + auto nested_column = column_map->getNestedColumnPtr(); + const auto * type_map = assert_cast(column_type.get()); + const DataTypePtr & nested_type = type_map->getNestedType(); + + auto * map_builder = assert_cast(array_builder); + auto builder = checkResult(arrow::MakeBuilder(arrow::list(map_builder->value_builder()->type())), column_name, format_name); + + auto list = buildArrowListArrayWithArrayColumnData(column_name, nested_column, nested_type, null_bytemap, builder.get(), format_name, start, end, settings, 
dictionary_values); + auto * list_array = assert_cast(list.get()); + + return std::make_shared(map_builder->type(), list_array->length(), list_array->value_offsets(), list_array->values(), list_array->null_bitmap()); + } + template static void fillArrowArrayWithLowCardinalityColumnDataImpl( const String & column_name, - ColumnPtr & column, + const ColumnPtr & column, const DataTypePtr & column_type, const PaddedPODArray *, arrow::ArrayBuilder * array_builder, @@ -458,10 +804,7 @@ namespace DB auto dict_column = dynamic_cast(*dict_values).getNestedNotNullableColumn(); const auto & dict_type = removeNullable(assert_cast(column_type.get())->getDictionaryType()); - fillArrowArray(column_name, dict_column, dict_type, nullptr, values_builder.get(), format_name, is_nullable, dict_column->size(), settings, dictionary_values); - std::shared_ptr arrow_dict_array; - status = values_builder->Finish(&arrow_dict_array); - checkStatus(status, column->getName(), format_name); + std::shared_ptr arrow_dict_array = fillArrowArray(column_name, dict_column, dict_type, nullptr, values_builder.get(), format_name, is_nullable, dict_column->size(), settings, dictionary_values); status = builder->InsertMemoValues(*arrow_dict_array); checkStatus(status, column->getName(), format_name); @@ -492,7 +835,7 @@ namespace DB static void fillArrowArrayWithLowCardinalityColumnData( const String & column_name, - ColumnPtr & column, + const ColumnPtr & column, const DataTypePtr & column_type, const PaddedPODArray * null_bytemap, arrow::ArrayBuilder * array_builder, @@ -798,9 +1141,9 @@ namespace DB checkStatus(status, write_column->getName(), format_name); } - static void fillArrowArray( + static std::shared_ptr fillArrowArray( const String & column_name, - ColumnPtr & column, + ColumnPtr column, const DataTypePtr & column_type, const PaddedPODArray * null_bytemap, arrow::ArrayBuilder * array_builder, @@ -810,6 +1153,11 @@ namespace DB const CHColumnToArrowColumn::Settings & settings, 
std::unordered_map & dictionary_values) { + std::shared_ptr arrow_array; + + column = column->convertToFullColumnIfConst(); + column = column->convertToFullColumnIfReplicated(); + switch (column_type->getTypeId()) { case TypeIndex::Nullable: @@ -819,12 +1167,12 @@ namespace DB DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); const ColumnPtr & null_column = column_nullable->getNullMapColumnPtr(); const PaddedPODArray & bytemap = assert_cast &>(*null_column).getData(); - fillArrowArray(column_name, nested_column, nested_type, &bytemap, array_builder, format_name, start, end, settings, dictionary_values); + arrow_array = fillArrowArray(column_name, nested_column, nested_type, &bytemap, array_builder, format_name, start, end, settings, dictionary_values); break; } case TypeIndex::String: { - if (settings.output_string_as_string) + if (settings.output_string_as_string && !array_builder->type()->Equals(arrow::binary())) fillArrowArrayWithStringColumnData(column, null_bytemap, format_name, array_builder, start, end); else fillArrowArrayWithStringColumnData(column, null_bytemap, format_name, array_builder, start, end); @@ -856,19 +1204,32 @@ namespace DB fillArrowArrayWithDate32ColumnData(column, null_bytemap, format_name, array_builder, start, end); break; case TypeIndex::Array: - fillArrowArrayWithArrayColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + arrow_array = buildArrowListArrayWithArrayColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); break; case TypeIndex::Tuple: - fillArrowArrayWithTupleColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + arrow_array = buildArrowStructArrayWithTupleColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, 
dictionary_values); break; case TypeIndex::LowCardinality: fillArrowArrayWithLowCardinalityColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); break; case TypeIndex::Map: { - ColumnPtr column_array = assert_cast(column.get())->getNestedColumnPtr(); - DataTypePtr array_type = assert_cast(column_type.get())->getNestedType(); - fillArrowArrayWithArrayColumnData(column_name, column_array, array_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + arrow_array = buildArrowMapArrayWithMapColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + break; + } + case TypeIndex::Variant: + { + const auto & column_variant = assert_cast(*column); + const auto & column_variant_type = assert_cast(*column_type); + arrow_array = buildArrowDenseUnionArrayWithVariantColumnData( + column_variant, + column_variant_type, + null_bytemap, + format_name, + start, + end, + settings, + dictionary_values); break; } case TypeIndex::Decimal32: @@ -919,8 +1280,18 @@ namespace DB FOR_INTERNAL_NUMERIC_TYPES(DISPATCH) #undef DISPATCH default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "Internal type '{}' of a column '{}' is not supported for conversion into {} data format.", column_type->getFamilyName(), column_name, format_name); + if (!settings.output_unsupported_types_as_binary) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Internal type '{}' of a column '{}' is not supported for conversion into {} data format.", column_type->getFamilyName(), column_name, format_name); + fillArrowArrayWithRawColumnData(column, null_bytemap, format_name, array_builder, start, end); + } + + if (!arrow_array) + { + auto status = array_builder->Finish(&arrow_array); + checkStatus(status, column->getName(), format_name); } + + return arrow_array; } static std::shared_ptr getArrowTypeForLowCardinalityIndexes(ColumnPtr indexes_column, 
const CHColumnToArrowColumn::Settings & settings) @@ -967,13 +1338,28 @@ namespace DB } static std::shared_ptr getArrowType( +<<<<<<< HEAD DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable) +======= + DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder) +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) { + if (column) + { + column = column->convertToFullColumnIfConst(); + column = column->convertToFullColumnIfReplicated(); + } + if (column_type->isNullable()) { DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); +<<<<<<< HEAD ColumnPtr nested_column = assert_cast(column.get())->getNestedColumnPtr(); auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable); +======= + ColumnPtr nested_column = column ? assert_cast(column.get())->getNestedColumnPtr() : nullptr; + auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder); +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) *out_is_column_nullable = true; return arrow_type; } @@ -1006,7 +1392,7 @@ namespace DB if (isArray(column_type)) { auto nested_type = assert_cast(column_type.get())->getNestedType(); - auto nested_column = assert_cast(column.get())->getDataPtr(); + auto nested_column = column ? 
assert_cast(column.get())->getDataPtr() : nullptr; bool is_item_nullable = false; auto nested_arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, &is_item_nullable); return arrow::list(std::make_shared("item", nested_arrow_type, is_item_nullable)); @@ -1017,12 +1403,16 @@ namespace DB const auto & tuple_type = assert_cast(column_type.get()); const auto & nested_types = tuple_type->getElements(); const auto & nested_names = tuple_type->getElementNames(); - const auto * tuple_column = assert_cast(column.get()); + const auto * tuple_column = column ? assert_cast(column.get()) : nullptr; std::vector> nested_fields; for (size_t i = 0; i != nested_types.size(); ++i) { bool is_field_nullable = false; +<<<<<<< HEAD auto nested_arrow_type = getArrowType(nested_types[i], tuple_column->getColumnPtr(i), nested_names[i], format_name, settings, &is_field_nullable); +======= + auto nested_arrow_type = getArrowType(nested_types[i], tuple_column ? tuple_column->getColumnPtr(i) : nullptr, nested_names[i], format_name, settings, &is_field_nullable, for_builder); +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) nested_fields.push_back(std::make_shared(nested_names[i], nested_arrow_type, is_field_nullable)); } return arrow::struct_(nested_fields); @@ -1031,12 +1421,32 @@ namespace DB if (column_type->lowCardinality()) { auto nested_type = assert_cast(column_type.get())->getDictionaryType(); +<<<<<<< HEAD const auto * lc_column = assert_cast(column.get()); const auto & nested_column = lc_column->getDictionary().getNestedColumn(); const auto & indexes_column = lc_column->getIndexesPtr(); return arrow::dictionary( getArrowTypeForLowCardinalityIndexes(indexes_column, settings), getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable)); +======= + if (column) + { + const auto * lc_column = assert_cast(column.get()); + const auto & nested_column = 
lc_column->getDictionary().getNestedColumn(); + const auto & indexes_column = lc_column->getIndexesPtr(); + return arrow::dictionary( + getArrowTypeForLowCardinalityIndexes(indexes_column, settings), + getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder)); + } + else + { + auto index_arrow_type = settings.use_64_bit_indexes_for_dictionary ? + (settings.use_signed_indexes_for_dictionary ? arrow::int64() : arrow::uint64()) : + (settings.use_signed_indexes_for_dictionary ? arrow::int32() : arrow::uint32()); + auto arrow_type = getArrowType(nested_type, nullptr, column_name, format_name, settings, out_is_column_nullable, for_builder); + return arrow::dictionary(index_arrow_type, arrow_type); + } +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } if (isMap(column_type)) @@ -1044,12 +1454,26 @@ namespace DB const auto * map_type = assert_cast(column_type.get()); const auto & key_type = map_type->getKeyType(); const auto & val_type = map_type->getValueType(); - const auto & columns = assert_cast(column.get())->getNestedData().getColumns(); + ColumnPtr key_column; + ColumnPtr value_column; + if (column) + { + const auto & columns = assert_cast(column.get())->getNestedData().getColumns(); + key_column = columns[0]; + value_column = columns[1]; + } +<<<<<<< HEAD bool _is_key_nullable = false; auto key_arrow_type = getArrowType(key_type, columns[0], column_name, format_name, settings, &_is_key_nullable); bool is_val_nullable = false; auto val_arrow_type = getArrowType(val_type, columns[1], column_name, format_name, settings, &is_val_nullable); +======= + bool is_key_nullable = false; + auto key_arrow_type = getArrowType(key_type, key_column, column_name, format_name, settings, &is_key_nullable, for_builder); + bool is_val_nullable = false; + auto val_arrow_type = getArrowType(val_type, value_column, column_name, format_name, settings, &is_val_nullable, for_builder); +>>>>>>> 
e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) return arrow::map( key_arrow_type, @@ -1080,6 +1504,69 @@ namespace DB if (isIPv4(column_type)) return arrow::uint32(); +<<<<<<< HEAD +======= + if (isVariant(column_type)) + { + const auto * column_variant = column ? &assert_cast(*column) : nullptr; + const auto & column_variant_type = assert_cast(*column_type); + + auto size = column_variant_type.getVariants().size(); + if (size > static_cast(std::numeric_limits::max())) + { + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Cannot convert Variant with {} nested types to {} Arrow DenseUnion: maximum supported is {} ", + size, + format_name, + static_cast(std::numeric_limits::max())); + } + + arrow::FieldVector fields; + + for (size_t i = 0; i < size; ++i) + { + const auto variant = column_variant ? column_variant->getVariantPtrByGlobalDiscriminator(i) : nullptr; + + bool is_column_nullable = false; + auto arrow_type = getArrowType( + column_variant_type.getVariant(i), + variant, + variant ? 
variant->getName() : "variant", + format_name, + settings, + &is_column_nullable, + for_builder); + + std::string field_name = column_variant_type.getVariant(i)->getFamilyName(); + fields.push_back(std::make_shared(field_name, arrow_type, is_column_nullable)); + } + + /// Variant in CH is slightly different than in arrow - it can indicate null value by having ColumnVariant::NULL_DISCRIMINATOR + /// in discriminators instead of using nullable type - because of this we need to introduce additional + /// null array (having a single null value) to have these null values to refer to + fields.push_back(std::make_shared("NULL", arrow::null(), false)); + + return arrow::dense_union(fields); + } + + if (isInterval(column_type)) + { + const auto * interval_type = assert_cast(column_type.get()); + switch (interval_type->getKind()) + { + case IntervalKind::Kind::Nanosecond: return arrow::duration(arrow::TimeUnit::NANO); + case IntervalKind::Kind::Microsecond: return arrow::duration(arrow::TimeUnit::MICRO); + case IntervalKind::Kind::Millisecond: return arrow::duration(arrow::TimeUnit::MILLI); + case IntervalKind::Kind::Second: return arrow::duration(arrow::TimeUnit::SECOND); + default: return arrow::int64(); + } + } + + if (isUUID(column_type)) + return for_builder ? 
arrow::fixed_size_binary(sizeof(UUID)) : std::make_shared(); + +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) if (isDate(column_type) && settings.output_date_as_uint16) return arrow::uint16(); @@ -1093,9 +1580,146 @@ namespace DB return arrow_type_it->second; } - throw Exception(ErrorCodes::UNKNOWN_TYPE, - "The type '{}' of a column '{}' is not supported for conversion into {} data format.", - column_type->getName(), column_name, format_name); + if (!settings.output_unsupported_types_as_binary) + throw Exception(ErrorCodes::UNKNOWN_TYPE, + "The type '{}' of a column '{}' is not supported for conversion into {} data format.", + column_type->getName(), column_name, format_name); + return arrow::binary(); + } + + std::shared_ptr CHColumnToArrowColumn::calculateArrowSchema( + const ColumnsWithTypeAndName & header_columns, + const std::string & format_name, + const Chunk * chunk, + const Settings & settings, + std::optional columns_num, + const std::optional> & column_to_field_id + ) + { + if (!columns_num) + columns_num = header_columns.size(); + + std::vector> arrow_fields; + arrow_fields.reserve(*columns_num); + + for (size_t column_i = 0; column_i < *columns_num; ++column_i) + { + const ColumnWithTypeAndName & header_column = header_columns[column_i]; + auto column_type = header_column.type; + auto column = chunk ? 
chunk->getColumns()[column_i] : header_column.column; + + if (!settings.low_cardinality_as_dictionary) + { + column_type = recursiveRemoveLowCardinality(column_type); + if (column) + column = recursiveRemoveLowCardinality(column); + } + + bool is_column_nullable = false; + auto arrow_type = getArrowType( + column_type, + column, + header_column.name, + format_name, + settings, + &is_column_nullable); + + std::shared_ptr field_metadata = nullptr; + + if (column_to_field_id && column_to_field_id->contains(header_column.name)) + { + Int64 field_id = column_to_field_id->at(header_column.name); + field_metadata = arrow::key_value_metadata({"PARQUET:field_id"}, {std::to_string(field_id)}); + } + + // Inject our UUID metadata if it's a root UUID column + if (isUUID(removeNullable(header_column.type))) + { + auto ext_metadata = arrow::key_value_metadata( + {"ARROW:extension:name", "ARROW:extension:metadata", "PARQUET:logical_type"}, + {"arrow.uuid", "", "UUID"} + ); + field_metadata = field_metadata ? field_metadata->Merge(*ext_metadata) : ext_metadata; + } + + if (field_metadata) + arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable, field_metadata)); + else + arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable)); + } + + return std::make_shared(arrow_fields); + } + + + std::shared_ptr CHColumnToArrowColumn::calculateArrowTable( + const ColumnsWithTypeAndName & header_columns, + const std::string & format_name, + const std::vector & chunks, + const Settings & settings, + size_t columns_num, + std::shared_ptr schema, + std::unordered_map * cached_dictionary_values) + { + /// Map {column name : arrow dictionary}. + /// To avoid converting dictionary from LowCardinality to Arrow + /// Dictionary every chunk we save it and reuse. + std::unordered_map local_dictionary_values; + std::unordered_map & dictionary_values = cached_dictionary_values ? 
*cached_dictionary_values : local_dictionary_values; + + std::vector table_data(columns_num); + + for (const auto & chunk : chunks) + { + /// For arrow::Table creation + for (size_t column_i = 0; column_i < columns_num; ++column_i) + { + const ColumnWithTypeAndName & header_column = header_columns[column_i]; + auto column_type = header_column.type; + auto column = chunk.getColumns()[column_i]; + + if (!settings.low_cardinality_as_dictionary) + { + column = recursiveRemoveLowCardinality(column); + column_type = recursiveRemoveLowCardinality(column_type); + } + + // Generate the unwrapped builder schema (safe for MakeBuilder) + bool is_column_nullable = false; + auto builder_type = getArrowType( + column_type, column, header_column.name, format_name, settings, &is_column_nullable, true /* for_builder */); + + std::unique_ptr array_builder; + arrow::Status status = MakeBuilder(arrow::default_memory_pool(), builder_type, &array_builder); + checkStatus(status, column->getName(), format_name); + + std::shared_ptr arrow_array = fillArrowArray( + header_column.name, + column, + column_type, + nullptr, + array_builder.get(), + format_name, + 0, + column->size(), + settings, + dictionary_values); + + // Zero-copy cast to the extension-rich schema (handles infinite nesting) + auto target_type = schema->field(static_cast(column_i))->type(); + if (!arrow_array->type()->Equals(*target_type)) + arrow_array = checkResult(arrow_array->View(target_type), column->getName(), format_name); + + table_data.at(column_i).emplace_back(std::move(arrow_array)); + } + } + + std::vector> columns; + columns.reserve(columns_num); + for (size_t column_i = 0; column_i < columns_num; ++column_i) + columns.emplace_back(std::make_shared(table_data.at(column_i))); + + return arrow::Table::Make(schema, columns); } CHColumnToArrowColumn::CHColumnToArrowColumn(const Block & header, const std::string & format_name_, const Settings & settings_) @@ -1135,6 +1759,7 @@ namespace DB { if (arrow_schema) return; 
+<<<<<<< HEAD if (!columns_num) columns_num = header_columns.size(); @@ -1170,6 +1795,9 @@ namespace DB } arrow_schema = std::make_shared(arrow_fields); +======= + arrow_schema = calculateArrowSchema(header_columns, format_name, chunk, settings, columns_num, column_to_field_id); +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } std::shared_ptr CHColumnToArrowColumn::getArrowSchema() const @@ -1185,12 +1813,11 @@ namespace DB size_t columns_num, const std::optional> & column_to_field_id) { - std::vector table_data(columns_num); - /// We use the first chunk to initialize the arrow schema. const Chunk * chunk_to_initialize_schema = chunks.empty() ? nullptr : chunks.data(); initializeArrowSchema(chunk_to_initialize_schema, columns_num, column_to_field_id); +<<<<<<< HEAD for (const auto & chunk : chunks) { /// For arrow::Table creation @@ -1232,6 +1859,9 @@ namespace DB columns.emplace_back(std::make_shared(table_data.at(column_i))); res = arrow::Table::Make(arrow_schema, columns); +======= + res = calculateArrowTable(header_columns, format_name, chunks, settings, columns_num, arrow_schema, &dictionary_values); +>>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } } diff --git a/src/Processors/Formats/Impl/CHColumnToArrowColumn.h b/src/Processors/Formats/Impl/CHColumnToArrowColumn.h index 2544faf9ed1d..d3a3094e021b 100644 --- a/src/Processors/Formats/Impl/CHColumnToArrowColumn.h +++ b/src/Processors/Formats/Impl/CHColumnToArrowColumn.h @@ -33,12 +33,35 @@ class CHColumnToArrowColumn bool use_64_bit_indexes_for_dictionary = false; /// Output Date as UInt16 instead of Arrow DATE32 for backward compatibility. bool output_date_as_uint16 = false; + /// Output types having no conversion as raw binary data. If false - such types would raise UNKNOWN_TYPE exception. 
+ bool output_unsupported_types_as_binary = false; }; + static std::shared_ptr calculateArrowSchema( + const ColumnsWithTypeAndName & header_columns, + const std::string & format_name, + const Chunk * chunk, + const Settings & settings, + std::optional columns_num = std::nullopt, + const std::optional> & column_to_field_id = std::nullopt + ); + + /// Because an arrow table can only have one dictionary per column, if the returned table is intended to be inserted into a larger table, + /// `cached_dictionary_values` should be provided to maintain this limitation. + static std::shared_ptr calculateArrowTable( + const ColumnsWithTypeAndName & header_columns, + const std::string & format_name, + const std::vector & chunks, + const Settings & settings, + size_t columns_num, + std::shared_ptr schema, + std::unordered_map * cached_dictionary_values = nullptr); + + CHColumnToArrowColumn(const Block & header, const std::string & format_name_, const Settings & settings_); CHColumnToArrowColumn(const ColumnsWithTypeAndName & header_columns_, const std::string & format_name_, const Settings & settings_); - /// Makes a copy of this converter. + /// Makes a copy of this converter. /// This can be useful to prepare for conversion in multiple threads. 
std::unique_ptr clone(bool copy_arrow_schema = false) const; diff --git a/src/Processors/Formats/Impl/ParquetBlockOutputFormat.cpp b/src/Processors/Formats/Impl/ParquetBlockOutputFormat.cpp index 3e938fddea10..73040ad8a528 100644 --- a/src/Processors/Formats/Impl/ParquetBlockOutputFormat.cpp +++ b/src/Processors/Formats/Impl/ParquetBlockOutputFormat.cpp @@ -339,7 +339,8 @@ void ParquetBlockOutputFormat::writeUsingArrow(std::vector chunks) CHColumnToArrowColumn::Settings { .output_string_as_string = format_settings.parquet.output_string_as_string, - .output_fixed_string_as_fixed_byte_array = format_settings.parquet.output_fixed_string_as_fixed_byte_array + .output_fixed_string_as_fixed_byte_array = format_settings.parquet.output_fixed_string_as_fixed_byte_array, + .output_unsupported_types_as_binary = format_settings.parquet.output_unsupported_types_as_binary, }); } diff --git a/src/Server/ArrowFlight/ArrowFlightServer.cpp b/src/Server/ArrowFlight/ArrowFlightServer.cpp new file mode 100644 index 000000000000..b70f2bebd29e --- /dev/null +++ b/src/Server/ArrowFlight/ArrowFlightServer.cpp @@ -0,0 +1,1228 @@ +#include + +#if USE_ARROWFLIGHT + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int UNKNOWN_EXCEPTION; + extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED; + extern const int UNKNOWN_SETTING; + extern const int SYNTAX_ERROR; +} + +namespace Setting +{ + extern const SettingsBool output_format_arrow_unsupported_types_as_binary; +} + + +using ArrowFlight::CallsData; +using ArrowFlight::Duration; +using ArrowFlight::hasTicketPrefix; +using ArrowFlight::hasPollDescriptorPrefix; +using ArrowFlight::PollDescriptorInfo; +using 
ArrowFlight::PollDescriptorWithExpirationTime; +using ArrowFlight::PollSession; +using ArrowFlight::Timestamp; + + +namespace +{ + /// Helper for std::visit with multiple lambda overloads + /// Usage: + /// std::variant v = 42; + /// auto result = std::visit(overloaded { + /// [](int i) { return std::to_string(i); }, + /// [](const std::string& s) { return s; }, + /// [](const auto& other) { return "unknown"; } + /// }, v); + template struct overloaded : Ts... { using Ts::operator()...; }; // NOLINT + template overloaded(Ts...) -> overloaded; + + String readFile(const String & filepath) + { + Poco::FileInputStream ifs(filepath); + String buf; + Poco::StreamCopier::copyToString(ifs, buf); + return buf; + } + + arrow::flight::Location addressToArrowLocation(const Poco::Net::SocketAddress & address_to_listen, bool use_tls) + { + auto ip_to_listen = address_to_listen.host(); + auto port_to_listen = address_to_listen.port(); + + /// Function arrow::flight::Location::ForGrpc*() builds an URL so it requires IPv6 address to be enclosed in brackets + String host_component = (ip_to_listen.family() == Poco::Net::AddressFamily::IPv6) ? 
("[" + ip_to_listen.toString() + "]") : ip_to_listen.toString(); + + arrow::Result parse_location_status; + if (use_tls) + parse_location_status = arrow::flight::Location::ForGrpcTls(host_component, port_to_listen); + else + parse_location_status = arrow::flight::Location::ForGrpcTcp(host_component, port_to_listen); + + if (!parse_location_status.ok()) + { + throw Exception( + ErrorCodes::UNKNOWN_EXCEPTION, + "Invalid address {} for Arrow Flight Server: {}", + address_to_listen.toString(), + parse_location_status.status().ToString()); + } + + return std::move(parse_location_status).ValueOrDie(); + } + + [[nodiscard]] arrow::Result convertPathToSQL(const std::vector & path, bool for_put_operation) + { + if (path.size() != 1) + return arrow::Status::Invalid("Flight descriptor's path should be one-component (got ", path.size(), " components)"); + if (path[0].empty()) + return arrow::Status::Invalid("Flight descriptor's path should specify the name of a table"); + const String & table_name = path[0]; + if (for_put_operation) + return "INSERT INTO " + backQuoteIfNeed(table_name) + " FORMAT Arrow"; + return "SELECT * FROM " + backQuoteIfNeed(table_name); + } + + [[nodiscard]] arrow::Result convertGetPathToSQL(const std::vector & path) + { + return convertPathToSQL(path, /* for_put_operation = */ false); + } + + [[nodiscard]] arrow::Result convertPutPathToSQL(const std::vector & path) + { + return convertPathToSQL(path, /* for_put_operation = */ true); + } + + using DecodeResult = std::tuple>; + + [[nodiscard]] + arrow::Result decodeDescriptor(const arrow::flight::FlightDescriptor & descriptor, bool for_put_operation) + { + switch (descriptor.type) + { + case arrow::flight::FlightDescriptor::PATH: + { + auto sql_res = for_put_operation ? 
convertPutPathToSQL(descriptor.path) : convertGetPathToSQL(descriptor.path); + ARROW_RETURN_NOT_OK(sql_res); + return DecodeResult {sql_res.ValueUnsafe(), {}, {}, {}}; + } + case arrow::flight::FlightDescriptor::CMD: + { + if (!for_put_operation && hasPollDescriptorPrefix(descriptor.cmd)) + return arrow::Status::Invalid("Method GetFlightInfo cannot be called with a flight descriptor returned by method PollFlightInfo"); + + auto res = ArrowFlight::commandSelector(descriptor.cmd); + if (const auto * result_table = res.getTable()) + { + ARROW_RETURN_NOT_OK(*result_table); + return DecodeResult {{}, {}, {}, result_table->ValueUnsafe()}; + } + const auto * sql_set = res.getSQLSet(); + return DecodeResult {sql_set->sql, sql_set->schema_modifier, sql_set->block_modifier, {}}; + } + default: + return arrow::Status::TypeError("Flight descriptor has unknown type ", magic_enum::enum_name(descriptor.type)); + } + } + + /// For method doGet() the pipeline should have an output. + [[nodiscard]] arrow::Status checkPipelineIsPulling(const QueryPipeline & pipeline) + { + if (!pipeline.pulling()) + return arrow::Status::Invalid("Query doesn't allow pulling data, use method doPut() with this kind of query"); + return arrow::Status::OK(); + } + + /// We don't allow custom formats except "Arrow" because they can't work with ArrowFlight. 
+ [[nodiscard]] arrow::Status checkNoCustomFormat(ASTPtr ast) + { + if (const auto * ast_with_output = dynamic_cast(ast.get())) + { + if (ast_with_output->format_ast && (getIdentifierName(ast_with_output->format_ast) != "Arrow")) + return arrow::Status::ExecutionError("Invalid format, only 'Arrow' format is supported"); + } + else if (const auto * insert = dynamic_cast(ast.get())) + { + if (!insert->format.empty() && insert->format != "Values" && insert->format != "Arrow") + return arrow::Status::ExecutionError("Invalid format (", insert->format, "), only 'Arrow' format is supported"); + } + return arrow::Status::OK(); + } + + /// Creates a converter to convert ClickHouse blocks to the Arrow format. + std::shared_ptr createCHToArrowConverter(const Block & header, ContextPtr query_context) + { + CHColumnToArrowColumn::Settings arrow_settings; + arrow_settings.output_string_as_string = true; + arrow_settings.output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]; + auto ch_to_arrow_converter = std::make_shared(header, "Arrow", arrow_settings); + ch_to_arrow_converter->initializeArrowSchema(); + return ch_to_arrow_converter; + } +} + + +ArrowFlightServer::ArrowFlightServer(IServer & server_, const Poco::Net::SocketAddress & address_to_listen_) + : server(server_) + , log(getLogger("ArrowFlightServer")) + , address_to_listen(address_to_listen_) + , tickets_lifetime_seconds(server.config().getUInt("arrowflight.tickets_lifetime_seconds", 600)) + , cancel_ticket_after_do_get(server.config().getBool("arrowflight.cancel_ticket_after_do_get", false)) + , poll_descriptors_lifetime_seconds(server.config().getUInt("arrowflight.poll_descriptors_lifetime_seconds", 600)) + , cancel_poll_descriptor_after_poll_flight_info(server.config().getBool("arrowflight.cancel_flight_descriptor_after_poll_flight_info", false)) + , calls_data( + std::make_unique( + tickets_lifetime_seconds ? 
std::make_optional(std::chrono::seconds{tickets_lifetime_seconds}) : std::optional{}, + poll_descriptors_lifetime_seconds ? std::make_optional(std::chrono::seconds{poll_descriptors_lifetime_seconds}) + : std::optional{}, + log)) +{ +} + +void ArrowFlightServer::start() +{ + chassert(!initialized && !stopped); + + bool use_tls = server.config().getBool("arrowflight.enable_ssl", false); + + auto location = addressToArrowLocation(address_to_listen, use_tls); + + arrow::flight::FlightServerOptions options(location); + options.auth_handler = std::make_unique(); + options.middleware.emplace_back(AUTHORIZATION_MIDDLEWARE_NAME, std::make_shared(server)); + + if (use_tls) + { + auto cert_path = server.config().getString("arrowflight.ssl_cert_file"); + auto key_path = server.config().getString("arrowflight.ssl_key_file"); + + auto cert = readFile(cert_path); + auto key = readFile(key_path); + + options.tls_certificates.push_back(arrow::flight::CertKeyPair{cert, key}); + } + + auto init_status = Init(options); + if (!init_status.ok()) + { + throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Failed init Arrow Flight Server: {}", init_status.ToString()); + } + + initialized = true; + + server_thread.emplace([this] + { + try + { + DB::setThreadName(ThreadName::ARROW_FLIGHT_SERVER); + if (stopped) + return; + auto serve_status = Serve(); + if (!serve_status.ok()) + LOG_ERROR(log, "Failed to serve Arrow Flight: {}", serve_status.ToString()); + } + catch (...) + { + tryLogCurrentException(log, "Failed to serve Arrow Flight"); + } + }); + + if (tickets_lifetime_seconds || poll_descriptors_lifetime_seconds) + { + cleanup_thread.emplace([this] + { + try + { + DB::setThreadName(ThreadName::ARROW_FLIGHT_EXPR); + while (!stopped) + { + calls_data->waitNextExpirationTime(); + calls_data->cancelExpired(); + } + } + catch (...) 
+ { + tryLogCurrentException(log, "Failed to cleanup"); + } + }); + } +} + +ArrowFlightServer::~ArrowFlightServer() = default; + +void ArrowFlightServer::stop() +{ + if (!initialized) + return; + + if (!stopped.exchange(true)) + { + try + { + auto status = Shutdown(); + if (!status.ok()) + LOG_ERROR(log, "Failed to shutdown Arrow Flight: {}", status.ToString()); + status = Wait(); + if (!status.ok()) + LOG_ERROR(log, "Failed to wait for shutdown Arrow Flight: {}", status.ToString()); + } + catch (...) + { + tryLogCurrentException(log, "Failed to shutdown Arrow Flight"); + } + if (server_thread) + { + server_thread->join(); + server_thread.reset(); + } + + calls_data->stopWaitingNextExpirationTime(); + if (cleanup_thread) + { + cleanup_thread->join(); + cleanup_thread.reset(); + } + calls_data.reset(); + } +} + +UInt16 ArrowFlightServer::portNumber() const +{ + return address_to_listen.port(); +} + +static size_t calculateTableBytes(const std::shared_ptr& table) +{ + int64_t total_bytes = 0; + for (const auto & chunked_array : table->columns()) + for (const auto & array : chunked_array->chunks()) + for (const auto& buffer : array->data()->buffers) + if (buffer) + total_bytes += buffer->size(); + return total_bytes; +} + +static ColumnsWithTypeAndName getHeader(const ColumnsWithTypeAndName & columns) +{ + ColumnsWithTypeAndName res; + for (const auto & column : columns) + res.emplace_back(column.cloneEmpty()); + return res; +} + +static std::shared_ptr getEmptyArrowTable(std::shared_ptr schema) +{ + size_t columns_num = schema->num_fields(); + std::vector> empty_columns; + empty_columns.reserve(columns_num); + + for (size_t i = 0; i < columns_num; ++i) + empty_columns.push_back(std::make_shared(arrow::ArrayVector{}, schema->field(static_cast(i))->type())); + + return arrow::Table::Make(schema, empty_columns); +} + +static arrow::Result, std::vector>>> executeSQLtoTables_impl( + const std::shared_ptr & session, + const std::string & sql, + bool single_table, + 
ArrowFlight::SchemaModifier schema_modifier = nullptr, + ArrowFlight::BlockModifier block_modifier = nullptr +) +{ + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. + QueryScope query_scope = QueryScope::create(query_context); + + std::shared_ptr schema; + std::vector> tables; + + auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); + + bool query_finished = false; + bool handling_exception = false; + SCOPE_EXIT({ + if (query_finished) + block_io.onFinish(); + else if (!handling_exception) + block_io.onCancelOrConnectionLoss(); + }); + + try + { + ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); + ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); + + PullingPipelineExecutor executor{block_io.pipeline}; + schema = CHColumnToArrowColumn::calculateArrowSchema( + executor.getHeader().getColumnsWithTypeAndName(), + "Arrow", + nullptr, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}); + + if (schema_modifier) + { + auto status = schema_modifier(schema); + ARROW_RETURN_NOT_OK(status); + schema = status.ValueUnsafe(); + } + + std::optional header; + std::vector chunks; + Block block; + while (executor.pull(block)) + { + if (!block.empty()) + { + if (block_modifier) + block_modifier(query_context, block); + if (!header) + header = getHeader(block.getColumnsWithTypeAndName()); + chunks.emplace_back(Chunk{block.getColumns(), block.rows()}); + if (!single_table) + { + tables.emplace_back( + CHColumnToArrowColumn::calculateArrowTable( + *header, "Arrow", chunks, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}, + header->size(), schema)); + chunks.clear(); + } + } + } + + if (!header) + 
tables.emplace_back(getEmptyArrowTable(schema)); + else if (single_table) + tables.emplace_back( + CHColumnToArrowColumn::calculateArrowTable( + *header, "Arrow", chunks, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}, + header->size(), schema)); + + query_finished = true; + } + catch (...) + { + handling_exception = true; + block_io.onException(); + throw; + } + + return std::tuple{schema, tables}; +} + +static arrow::Result, std::vector>>> executeSQLtoTables( + const std::shared_ptr & session, + const std::string & sql, + ArrowFlight::SchemaModifier schema_modifier = nullptr, + ArrowFlight::BlockModifier block_modifier = nullptr +) +{ + return executeSQLtoTables_impl(session, sql, false, schema_modifier, block_modifier); +} + +static arrow::Result, std::shared_ptr>> executeSQLtoTable( + const std::shared_ptr & session, + const std::string & sql, + ArrowFlight::SchemaModifier schema_modifier = nullptr, + ArrowFlight::BlockModifier block_modifier = nullptr +) +{ + auto res = executeSQLtoTables_impl(session, sql, true, schema_modifier, block_modifier); + ARROW_RETURN_NOT_OK(res); + return std::tuple{std::get<0>(res.ValueUnsafe()), std::get<1>(res.ValueUnsafe()).front()}; +} + +arrow::Status ArrowFlightServer::GetFlightInfo( + const arrow::flight::ServerCallContext & context, + const arrow::flight::FlightDescriptor & request, + std::unique_ptr * info) +{ + auto impl = [&] + { + LOG_INFO(log, "GetFlightInfo is called for descriptor {}", request.ToString()); + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + std::string sql; + ArrowFlight::SchemaModifier schema_modifier; + ArrowFlight::BlockModifier block_modifier; + std::shared_ptr table; + std::shared_ptr schema; + + ARROW_ASSIGN_OR_RAISE(std::tie(sql, schema_modifier, block_modifier, table), decodeDescriptor(request, false)) + chassert(!sql.empty() || 
table); + + std::vector endpoints; + int64_t total_rows = 0; + int64_t total_bytes = 0; + + if (table) + { + schema = table->schema(); + total_rows = table->num_rows(); + total_bytes = calculateTableBytes(table); + auto ticket_info = calls_data->createTicket(table); + arrow::flight::FlightEndpoint endpoint; + endpoint.ticket = arrow::flight::Ticket(ticket_info->ticket); + endpoint.expiration_time = ticket_info->expiration_time; + endpoints.emplace_back(endpoint); + } + else + { + // We generate a table for every chunk of data, which then produces ticket for every table + // so clients can parallelize data retrieval. + // However, it's unclear if this is necessary since we later indicate that data is ordered + // and all endpoints are local. This forces clients to request data through the same connection, + // and even with gRPC, clients are forced to prioritize the order. + // TODO: Consider single ticket optimization for ordered local data to reduce overhead (executeSQLtoTable) + std::vector> tables; + ARROW_ASSIGN_OR_RAISE(std::tie(schema, tables) , executeSQLtoTables(session, sql, schema_modifier, block_modifier)) + + for (auto & t : tables) + { + total_rows += t->num_rows(); + total_bytes += calculateTableBytes(t); + auto ticket_info = calls_data->createTicket(t); + arrow::flight::FlightEndpoint endpoint; + endpoint.ticket = arrow::flight::Ticket(ticket_info->ticket); + endpoint.expiration_time = ticket_info->expiration_time; + endpoints.emplace_back(endpoint); + } + } + + auto flight_info_res = arrow::flight::FlightInfo::Make( + *schema, + request, + endpoints, + total_rows, + total_bytes, + /* ordered = */ true); + + ARROW_RETURN_NOT_OK(flight_info_res); + *info = std::make_unique(std::move(flight_info_res).ValueUnsafe()); + + LOG_INFO(log, "GetFlightInfo returns flight info {}", (*info)->ToString()); + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("GetFlightInfo", impl); +} + + +arrow::Status ArrowFlightServer::GetSchema( + const 
arrow::flight::ServerCallContext & context, + const arrow::flight::FlightDescriptor & request, + std::unique_ptr * schema_result) +{ + auto impl = [&] + { + LOG_INFO(log, "GetSchema is called for descriptor {}", request.ToString()); + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + std::shared_ptr schema; + + if ((request.type == arrow::flight::FlightDescriptor::CMD) && hasPollDescriptorPrefix(request.cmd)) + { + const String & poll_descriptor = request.cmd; + ARROW_RETURN_NOT_OK(calls_data->extendPollDescriptorExpirationTime(poll_descriptor)); + auto poll_info_res = calls_data->getPollDescriptorInfo(poll_descriptor); + ARROW_RETURN_NOT_OK(poll_info_res); + const auto & poll_info = poll_info_res.ValueOrDie(); + schema = poll_info->schema; + } + else + { + std::string sql; + ArrowFlight::SchemaModifier schema_modifier; + ArrowFlight::BlockModifier block_modifier; + std::shared_ptr table; + + ARROW_ASSIGN_OR_RAISE(std::tie(sql, schema_modifier, block_modifier, table), decodeDescriptor(request, false)) + chassert(!sql.empty() || table); + + if (table) + schema = table->schema(); + else + { + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
+ QueryScope query_scope = QueryScope::create(query_context); + + auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); + + bool query_finished = false; + bool handling_exception = false; + SCOPE_EXIT({ + if (query_finished) + block_io.onFinish(); + else if (!handling_exception) + block_io.onCancelOrConnectionLoss(); + }); + + try + { + ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); + ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); + + PullingPipelineExecutor executor{block_io.pipeline}; + + schema = CHColumnToArrowColumn::calculateArrowSchema( + executor.getHeader().getColumnsWithTypeAndName(), "Arrow", nullptr, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}); + if (schema_modifier) + { + auto status = schema_modifier(schema); + ARROW_RETURN_NOT_OK(status); + schema = status.ValueUnsafe(); + } + + query_finished = true; + } + catch (...) 
+ { + handling_exception = true; + block_io.onException(); + throw; + } + } + } + + auto schema_res = arrow::flight::SchemaResult::Make(*schema); + ARROW_RETURN_NOT_OK(schema_res); + *schema_result = std::make_unique(*std::move(schema_res).ValueUnsafe()); + + LOG_INFO(log, "GetSchema returns schema {}", schema->ToString()); + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("GetSchema", impl); +} + + +arrow::Status ArrowFlightServer::PollFlightInfo( + const arrow::flight::ServerCallContext & context, + const arrow::flight::FlightDescriptor & request, + std::unique_ptr * info) +{ + auto impl = [&] + { + LOG_INFO(log, "PollFlightInfo is called for descriptor {}", request.ToString()); + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + std::shared_ptr poll_info; + std::shared_ptr schema; + std::optional next_poll_descriptor; + bool should_cancel_poll_descriptor = false; + + arrow::flight::FlightDescriptor original_flight_descriptor; + std::string query_id; + + if ((request.type == arrow::flight::FlightDescriptor::CMD) && hasPollDescriptorPrefix(request.cmd)) + { + const String & poll_descriptor = request.cmd; + ARROW_RETURN_NOT_OK(evaluatePollDescriptor(poll_descriptor)); + ARROW_RETURN_NOT_OK(calls_data->extendPollDescriptorExpirationTime(poll_descriptor)); + auto poll_info_res = calls_data->getPollDescriptorInfo(poll_descriptor); + ARROW_RETURN_NOT_OK(poll_info_res); + poll_info = poll_info_res.ValueOrDie(); + original_flight_descriptor = poll_info->original_flight_descriptor; + query_id = poll_info->query_id; + schema = poll_info->schema; + if (poll_info->next_poll_descriptor) + next_poll_descriptor = calls_data->getPollDescriptorWithExpirationTime(*poll_info->next_poll_descriptor); + should_cancel_poll_descriptor = cancel_poll_descriptor_after_poll_flight_info; + } + else + { + std::string sql; + ArrowFlight::SchemaModifier schema_modifier; + ArrowFlight::BlockModifier block_modifier; + std::shared_ptr table; 
+ + ARROW_ASSIGN_OR_RAISE(std::tie(sql, schema_modifier, block_modifier, table), decodeDescriptor(request, false)) + chassert(!sql.empty() || table); + + if (table) + { + auto ticket_info = calls_data->createTicket(table); + std::vector endpoints; + arrow::flight::FlightEndpoint endpoint; + endpoint.ticket = arrow::flight::Ticket(ticket_info->ticket); + endpoint.expiration_time = ticket_info->expiration_time; + endpoints.emplace_back(endpoint); + + auto flight_info_res = arrow::flight::FlightInfo::Make(*table->schema(), request, endpoints, table->num_rows(), calculateTableBytes(table), /* ordered = */ true); + ARROW_RETURN_NOT_OK(flight_info_res); + auto flight_info = std::make_unique(flight_info_res.ValueOrDie()); + *info = std::make_unique(std::move(flight_info), std::nullopt, std::nullopt, std::nullopt); + + LOG_INFO(log, "PollFlightInfo returns {}", (*info)->ToString()); + return arrow::Status::OK(); + } + + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
+ + auto thread_group = ThreadGroup::createForQuery(query_context); + CurrentThread::attachToGroup(thread_group); + SCOPE_EXIT({ CurrentThread::detachFromGroupIfNotDetached(); }); + + auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); + + bool block_io_owned_here = true; + SCOPE_EXIT({ + if (block_io_owned_here) + block_io.onCancelOrConnectionLoss(); + }); + + ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); + ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); + + block_io_owned_here = false; + auto poll_session = std::make_unique(query_context, thread_group, std::move(block_io), schema_modifier, block_modifier); + + schema = poll_session->getSchema(); + + original_flight_descriptor = request; + query_id = query_context->getCurrentQueryId(); + auto next_info = calls_data->createPollDescriptor(std::move(poll_session), original_flight_descriptor, query_id); + next_poll_descriptor = *next_info; + } + + std::vector endpoints; + int64_t total_rows = 0; + int64_t total_bytes = 0; + + while (poll_info) + { + if (poll_info->ticket) + { + arrow::flight::FlightEndpoint endpoint; + endpoint.ticket = arrow::flight::Ticket{*poll_info->ticket}; + endpoint.expiration_time = calls_data->getTicketExpirationTime(*poll_info->ticket); + endpoints.emplace_back(endpoint); + } + if (poll_info->rows) + total_rows += *poll_info->rows; + if (poll_info->bytes) + total_bytes += *poll_info->bytes; + poll_info = poll_info->previous_info; + } + std::reverse(endpoints.begin(), endpoints.end()); + + auto flight_info_res = arrow::flight::FlightInfo::Make(*schema, original_flight_descriptor, endpoints, total_rows, total_bytes, /* ordered = */ true, query_id); + ARROW_RETURN_NOT_OK(flight_info_res); + std::unique_ptr flight_info = std::make_unique(flight_info_res.ValueOrDie()); + + std::optional next; + std::optional expiration_time; + if (next_poll_descriptor) + { + next = 
arrow::flight::FlightDescriptor::Command(next_poll_descriptor->poll_descriptor); + expiration_time = next_poll_descriptor->expiration_time; + } + + *info = std::make_unique(std::move(flight_info), std::move(next), std::nullopt, expiration_time); + + if (should_cancel_poll_descriptor) + calls_data->cancelPollDescriptor(request.cmd); + + LOG_INFO(log, "PollFlightInfo returns {}", (*info)->ToString()); + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("PollFlightInfo", impl); +} + + +/// evaluatePollDescriptors() pulls a block from the query pipeline. +/// This function blocks until it either gets a nonempty block from the query pipeline or finds out that there will be no blocks anymore. +/// +/// NOTE: The current implementation doesn't allow to set a timeout to avoid blocking calls as it's suggested in the documentation +/// for PollFlightInfo (see https://arrow.apache.org/docs/format/Flight.html#downloading-data-by-running-a-heavy-query). +arrow::Status ArrowFlightServer::evaluatePollDescriptor(const String & poll_descriptor) +{ + auto poll_session_res = calls_data->startEvaluation(poll_descriptor); + ARROW_RETURN_NOT_OK(poll_session_res); + auto poll_session = std::move(poll_session_res).ValueOrDie(); + + if (!poll_session) + { + /// Already evaluated. 
+ auto info_res = calls_data->getPollDescriptorInfo(poll_descriptor); + ARROW_RETURN_NOT_OK(info_res); + const auto & info = info_res.ValueOrDie(); + if (!info->evaluated) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Session is not attached to non-evaluated poll descriptor {}", poll_descriptor); + return *info->status; + } + + ThreadGroupSwitcher thread_group_switcher{poll_session->getThreadGroup(), ThreadName::ARROW_FLIGHT}; + + std::optional ticket; + try + { + UInt64 rows = 0; + UInt64 bytes = 0; + Block block; + while (poll_session->getNextBlock(block)) + { + if (block.empty()) + continue; + + auto header = getHeader(block.getColumnsWithTypeAndName()); + rows = block.rows(); + bytes = block.bytes(); + std::vector chunks; + chunks.emplace_back(Chunk{std::move(block).getColumns(), rows}); + std::shared_ptr table = CHColumnToArrowColumn::calculateArrowTable( + header, "Arrow", chunks, + {.output_string_as_string = true, .output_unsupported_types_as_binary = poll_session->queryContext()->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}, + header.size(), poll_session->getSchema()); + auto ticket_info = calls_data->createTicket(table); + ticket = ticket_info->ticket; + break; + } + + if (!ticket) + poll_session->onFinish(); + calls_data->endEvaluation(poll_descriptor, ticket, rows, bytes, !ticket); + } + catch (...) 
+ { + tryLogCurrentException(log, "Poll: Failed to get next block"); + auto error_status = arrow::Status::ExecutionError("Poll: Failed to get next block: ", getCurrentExceptionMessage(/* with_stacktrace = */ false)); + calls_data->endEvaluationWithError(poll_descriptor, error_status); + poll_session->onException(); + return error_status; + } + + auto info_res = calls_data->getPollDescriptorInfo(poll_descriptor); + if (!info_res.ok()) + { + if (ticket) + poll_session->onCancelOrConnectionLoss(); + return info_res.status(); + } + const auto & info = info_res.ValueOrDie(); + if (!ticket) + calls_data->eraseFlightDescriptorMapByDescriptor(poll_descriptor); + else + calls_data->createPollDescriptor(std::move(poll_session), info); + + return arrow::Status::OK(); +} + + +arrow::Status ArrowFlightServer::DoGet( + const arrow::flight::ServerCallContext & context, + const arrow::flight::Ticket & request, + std::unique_ptr * stream) +{ + auto impl = [&] + { + LOG_INFO(log, "DoGet is called for ticket {}", request.ticket); + std::vector chunks; + std::shared_ptr table; + bool should_cancel_ticket = false; + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + if (hasTicketPrefix(request.ticket)) + { + auto ticket_info_res = calls_data->getTicketInfo(request.ticket); + ARROW_RETURN_NOT_OK(ticket_info_res); + const auto & ticket_info = ticket_info_res.ValueOrDie(); + table = ticket_info->arrow_table; + should_cancel_ticket = cancel_ticket_after_do_get; + } + else + { + const String & sql = request.ticket; + + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
+ QueryScope query_scope = QueryScope::create(query_context); + + auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); + + bool query_finished = false; + bool handling_exception = false; + SCOPE_EXIT({ + if (query_finished) + block_io.onFinish(); + else if (!handling_exception) + block_io.onCancelOrConnectionLoss(); + }); + + try + { + ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); + ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); + + PullingPipelineExecutor executor{block_io.pipeline}; + + Block block; + while (executor.pull(block)) + chunks.emplace_back(Chunk(block.getColumns(), block.rows())); + + auto header = executor.getHeader(); + auto ch_to_arrow_converter = createCHToArrowConverter(header, query_context); + ch_to_arrow_converter->chChunkToArrowTable(table, chunks, header.columns()); + + query_finished = true; + } + catch (...) + { + handling_exception = true; + block_io.onException(); + throw; + } + } + + auto stream_res = arrow::RecordBatchReader::MakeFromIterator( + arrow::Iterator>{arrow::TableBatchReader{table}}, table->schema()); + ARROW_RETURN_NOT_OK(stream_res); + *stream = std::make_unique(stream_res.ValueOrDie()); + + if (should_cancel_ticket) + calls_data->cancelTicket(request.ticket); + + LOG_INFO(log, "DoGet succeeded"); + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("DoGet", impl); +} + + +arrow::Status ArrowFlightServer::DoPut( + const arrow::flight::ServerCallContext & context, + std::unique_ptr reader, + std::unique_ptr writer) +{ + auto impl = [&] + { + const auto & request = reader->descriptor(); + LOG_INFO(log, "DoPut is called for descriptor {}", request.ToString()); + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + bool dont_write_flight_sql_metadata = !ArrowFlight::flightDescriptorIsArrowFlightSqlCommand(request); + + std::string sql; + ArrowFlight::SchemaModifier schema_modifier; + ArrowFlight::BlockModifier 
block_modifier; + std::shared_ptr table; + + ARROW_ASSIGN_OR_RAISE(std::tie(sql, schema_modifier, block_modifier, table), decodeDescriptor(request, true)) + /// DoPut command should only produce sql query + chassert(!sql.empty() && !schema_modifier && !block_modifier && !table); + + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. + QueryScope query_scope = QueryScope::create(query_context); + + auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); + + bool query_finished = false; + bool handling_exception = false; + SCOPE_EXIT({ + if (query_finished) + block_io.onFinish(); + else if (!handling_exception) + block_io.onCancelOrConnectionLoss(); + }); + + try + { + ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); + auto & pipeline = block_io.pipeline; + + if (pipeline.pushing()) + { + Block header = pipeline.getHeader(); + auto input = std::make_shared(std::move(reader), header, query_context); + pipeline.complete(Pipe(std::move(input))); + } + else if (pipeline.pulling()) + { + Block header = pipeline.getHeader(); + auto output = std::make_shared(std::make_shared(header)); + pipeline.complete(std::move(output)); + } + + if (pipeline.completed()) + { + CompletedPipelineExecutor executor(pipeline); + executor.execute(); + } + + query_finished = true; + } + catch (...) 
+ { + handling_exception = true; + block_io.onException(); + throw; + } + + if (!dont_write_flight_sql_metadata) + { + arrow::flight::protocol::sql::DoPutUpdateResult update_result; + if (auto element = query_context->getProcessListElement()) + update_result.set_record_count(element->getInfo().written_rows); + else + update_result.set_record_count(0); + + ARROW_RETURN_NOT_OK(writer->WriteMetadata(*arrow::Buffer::FromString(update_result.SerializeAsString()))); + } + + LOG_INFO(log, "DoPut succeeded"); + + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("DoPut", impl); +} + + +arrow::Status ArrowFlightServer::tryRunAndLogIfError(std::string_view method_name, std::function && func) const +{ + DB::setThreadName(ThreadName::ARROW_FLIGHT); + ThreadStatus thread_status; + try + { + auto status = std::move(func)(); + if (!status.ok()) + LOG_ERROR(log, "{} failed: {}", method_name, status.ToString()); + return status; + } + catch (...) + { + tryLogCurrentException(log, fmt::format("{} failed", method_name)); + return arrow::Status::ExecutionError(method_name, " failed: ", getCurrentExceptionMessage(/* with_stacktrace = */ false)); + } +} + + +arrow::Status ArrowFlightServer::DoAction( + const arrow::flight::ServerCallContext & context, + const arrow::flight::Action & action, + std::unique_ptr * result_stream) +{ + auto impl = [&] + { + LOG_INFO(log, "DoAction is called for action {} {}", action.type, action.ToString()); + + const auto & auth = AuthMiddleware::get(context); + auto session = auth.getSession(); + + std::vector results; + + if (action.type == arrow::flight::ActionType::kCancelFlightInfo.type) + { + if (!action.body) + return arrow::Status::Invalid("Invalid empty CancelFlightInfo action."); + ARROW_ASSIGN_OR_RAISE(auto request, arrow::flight::CancelFlightInfoRequest::Deserialize({action.body->data_as(), static_cast(action.body->size())})) + LOG_DEBUG(log, "CancelFlightInfo request {}", request.ToString()); + auto query_id = 
request.info->app_metadata(); + auto result = arrow::flight::CancelFlightInfoResult{arrow::flight::CancelStatus::kNotCancellable}; + + if (!query_id.empty()) + { + auto & process_list = server.context()->getProcessList(); + auto cancel_result = process_list.sendCancelToQuery(query_id, auth.getUsername()); + if (cancel_result == CancellationCode::CancelSent) + { + result = arrow::flight::CancelFlightInfoResult{arrow::flight::CancelStatus::kCancelled}; + + for (const auto & pd : calls_data->collectPollDescriptorsForQueryId(query_id)) + calls_data->cancelPollDescriptor(pd); + } + } + + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()) + ARROW_ASSIGN_OR_RAISE(auto packed_result, arrow::Result{arrow::flight::Result{arrow::Buffer::FromString(std::move(serialized))}}) + results.push_back(std::move(packed_result)); + } + else if (action.type == arrow::flight::ActionType::kSetSessionOptions.type) + { + if (!action.body) + return arrow::Status::Invalid("Invalid empty SetSessionOptions action."); + ARROW_ASSIGN_OR_RAISE(auto request, arrow::flight::SetSessionOptionsRequest::Deserialize({action.body->data_as(), static_cast(action.body->size())})) + arrow::flight::SetSessionOptionsResult result; + + auto query_context = session->makeQueryContext(); + auto session_context = query_context->getSessionContext(); + + /// Convert Arrow Flight SessionOptionValue to a string representation + /// suitable for Context::setSetting(). + auto to_string_value = overloaded { + [](const std::string & str) { return str; }, + [](bool b) { return std::string(b ? 
"true" : "false"); }, + [](int64_t v) { return std::to_string(v); }, + [](double v) { return fmt::format("{}", v); }, + [](const std::vector & strings) + { + std::string res = "["; + for (size_t i = 0; i < strings.size(); ++i) + { + if (i > 0) res += ","; + res += quoteString(strings[i]); + } + res += "]"; + return res; + }, + /// std::monostate is deliberately excluded here — it means "reset to default" + /// and is handled separately instead of calling this visitor. + [](const std::monostate &) -> std::string + { + chassert(false && "std::monostate should be handled separately instead of calling this visitor"); + return ""; + } + }; + + for (const auto & [setting, value] : request.session_options) + { + if (!isValidIdentifier(setting)) + { + result.errors[setting] = arrow::flight::SetSessionOptionsResult::Error{ + arrow::flight::SetSessionOptionErrorValue::kInvalidName + }; + continue; + } + + try + { + if (std::holds_alternative(value)) + { + /// std::monostate means "reset to default" (SET setting = DEFAULT). 
+ session_context->resetSettingsToDefaultValue({setting}); + } + else + { + auto string_value = std::visit(to_string_value, value); + SettingChange change{setting, Field{string_value}}; + query_context->checkSettingsConstraints(change, SettingSource::QUERY); + session_context->setSetting(setting, string_value); + } + } + catch (DB::Exception & e) + { + auto error_value = [&]() + { + if (e.code() == ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED || e.code() == ErrorCodes::SYNTAX_ERROR) + return arrow::flight::SetSessionOptionErrorValue::kInvalidValue; + else if (e.code() == ErrorCodes::UNKNOWN_SETTING) + return arrow::flight::SetSessionOptionErrorValue::kInvalidName; + else + return arrow::flight::SetSessionOptionErrorValue::kUnspecified; + }(); + + result.errors[setting] = arrow::flight::SetSessionOptionsResult::Error{error_value}; + } + } + + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()) + ARROW_ASSIGN_OR_RAISE(auto packed_result, arrow::Result{arrow::flight::Result{arrow::Buffer::FromString(std::move(serialized))}}) + + results.push_back(std::move(packed_result)); + } + else if (action.type == arrow::flight::ActionType::kGetSessionOptions.type) + { + std::string_view body_view = action.body + ? 
std::string_view{action.body->data_as(), static_cast(action.body->size())} + : std::string_view{}; + ARROW_RETURN_NOT_OK(arrow::flight::GetSessionOptionsRequest::Deserialize(body_view)); + arrow::flight::GetSessionOptionsResult result; + + auto execute_res = executeSQLtoTable(session, "SELECT name, value FROM system.settings"); + ARROW_RETURN_NOT_OK(execute_res); + auto [_, table] = execute_res.ValueUnsafe(); + const auto & names = table->column(0); + const auto & values = table->column(1); + + if (names->num_chunks() != values->num_chunks()) + return arrow::Status::Invalid("Unexpected chunk layout mismatch for settings columns"); + + for (int chunk_idx = 0; chunk_idx < names->num_chunks(); ++chunk_idx) + { + const auto & name_chunk_any = names->chunk(chunk_idx); + const auto & value_chunk_any = values->chunk(chunk_idx); + + if (name_chunk_any->type_id() != arrow::Type::STRING || value_chunk_any->type_id() != arrow::Type::STRING) + return arrow::Status::TypeError("Expected STRING chunks in settings result"); + + if (name_chunk_any->length() != value_chunk_any->length()) + return arrow::Status::Invalid("Mismatched chunk lengths for settings columns"); + + const auto & name_chunk = static_cast(*name_chunk_any); + const auto & value_chunk = static_cast(*value_chunk_any); + + for (int64_t i = 0; i < name_chunk.length(); ++i) + { + if (name_chunk.IsNull(i) || value_chunk.IsNull(i)) + continue; + result.session_options[name_chunk.GetString(i)] = value_chunk.GetString(i); + } + } + + ARROW_ASSIGN_OR_RAISE(auto serialized, result.SerializeToString()) + ARROW_ASSIGN_OR_RAISE(auto packed_result, arrow::Result{arrow::flight::Result{arrow::Buffer::FromString(std::move(serialized))}}) + + results.push_back(std::move(packed_result)); + } + else + return arrow::Status::NotImplemented(action.type, " is not supported"); + + *result_stream = std::make_unique(std::move(results)); + return arrow::Status::OK(); + }; + return tryRunAndLogIfError("DoAction", impl); +} + +} + +#endif diff 
--git a/src/Server/ArrowFlightHandler.h b/src/Server/ArrowFlight/ArrowFlightServer.h similarity index 88% rename from src/Server/ArrowFlightHandler.h rename to src/Server/ArrowFlight/ArrowFlightServer.h index 7648a9822d8d..8fa37e363214 100644 --- a/src/Server/ArrowFlightHandler.h +++ b/src/Server/ArrowFlight/ArrowFlightServer.h @@ -3,6 +3,7 @@ #include "config.h" #if USE_ARROWFLIGHT + #include #include #include @@ -11,12 +12,14 @@ namespace DB { -class ArrowFlightHandler : public IGRPCServer, public arrow::flight::FlightServerBase +namespace ArrowFlight { class CallsData; } + +class ArrowFlightServer : public IGRPCServer, public arrow::flight::FlightServerBase { public: - explicit ArrowFlightHandler(IServer & server_, const Poco::Net::SocketAddress & address_to_listen_); + explicit ArrowFlightServer(IServer & server_, const Poco::Net::SocketAddress & address_to_listen_); - ~ArrowFlightHandler() override; + ~ArrowFlightServer() override; void start() override; @@ -75,8 +78,7 @@ class ArrowFlightHandler : public IGRPCServer, public arrow::flight::FlightServe const UInt64 poll_descriptors_lifetime_seconds; const bool cancel_poll_descriptor_after_poll_flight_info; - class CallsData; - std::unique_ptr calls_data; + std::unique_ptr calls_data; }; } diff --git a/src/Server/ArrowFlight/AuthMiddleware.cpp b/src/Server/ArrowFlight/AuthMiddleware.cpp new file mode 100644 index 000000000000..77925185a054 --- /dev/null +++ b/src/Server/ArrowFlight/AuthMiddleware.cpp @@ -0,0 +1,256 @@ +#include + +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int INVALID_SESSION_TIMEOUT; +} + +void AuthMiddleware::SendingHeaders(arrow::flight::AddCallHeaders * outgoing_headers) +{ + if (!token.empty()) + outgoing_headers->AddHeader(AUTHORIZATION_HEADER, "Bearer " + token); +} + +void AuthMiddleware::CallCompleted(const arrow::Status & /*status*/) +{ + if (!session) + return; + + if 
(!session_id.empty()) + { + if (session_close) + session->closeSession(session_id); + else + session->releaseSessionID(); + } +} + + +namespace +{ + std::chrono::steady_clock::duration parseSessionTimeout( + const Poco::Util::AbstractConfiguration & config, + unsigned query_session_timeout) + { + unsigned session_timeout = config.getInt("default_session_timeout", 60); + + if (query_session_timeout) + { + session_timeout = query_session_timeout; + unsigned max_session_timeout = config.getUInt("max_session_timeout", 3600); + + if (session_timeout > max_session_timeout) + throw Exception(ErrorCodes::INVALID_SESSION_TIMEOUT, "Session timeout '{}' is larger than max_session_timeout: {}. " + "Maximum session timeout could be modified in configuration file.", + session_timeout, max_session_timeout); + } + + return std::chrono::seconds(session_timeout); + } + + std::optional> getCredentialsFromBasicHeader(const arrow::flight::CallHeaders & headers) + { + auto it = std::ranges::find_if(headers, [](const auto & p) { return Poco::toLower(std::string(p.first)) == "authorization"; }); + if (it == headers.end()) + return std::nullopt; + + const std::string basic_prefix = "basic "; + const auto & auth_str = it->second; + + if (!Poco::toLower(std::string(auth_str)).starts_with(basic_prefix)) + return std::nullopt; + + auto credentials = base64Decode(std::string(auth_str.substr(basic_prefix.size()))); + + auto pos = credentials.find(':'); + if (pos == std::string::npos) + return {{credentials, ""}}; + + return {{credentials.substr(0, pos), credentials.substr(pos+1)}}; + } + + std::optional getTokenFromBearerHeader(const arrow::flight::CallHeaders & headers) + { + auto it = std::ranges::find_if(headers, [](const auto & p) { return Poco::toLower(std::string(p.first)) == "authorization"; }); + if (it == headers.end()) + return std::nullopt; + + const std::string bearer_prefix = "bearer "; + const auto & auth_str = it->second; + + if 
(!Poco::toLower(std::string(auth_str)).starts_with(bearer_prefix)) + return std::nullopt; + + return std::string(auth_str.substr(bearer_prefix.size())); + } + + /// Extracts the client's address from the call context. + Poco::Net::SocketAddress getClientAddress(const arrow::flight::ServerCallContext & context) + { + /// Returns a string like ipv4:127.0.0.1:55930 or ipv6:%5B::1%5D:55930 + String uri_encoded_peer = context.peer(); + + constexpr const std::string_view ipv4_prefix = "ipv4:"; + constexpr const std::string_view ipv6_prefix = "ipv6:"; + + bool ipv4 = uri_encoded_peer.starts_with(ipv4_prefix); + bool ipv6 = uri_encoded_peer.starts_with(ipv6_prefix); + + if (!ipv4 && !ipv6) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected ipv4 or ipv6 protocol in peer address, got {}", uri_encoded_peer); + + auto prefix = ipv4 ? ipv4_prefix : ipv6_prefix; + auto family = ipv4 ? Poco::Net::AddressFamily::Family::IPv4 : Poco::Net::AddressFamily::Family::IPv6; + + uri_encoded_peer = uri_encoded_peer.substr(prefix.length()); + + String peer; + Poco::URI::decode(uri_encoded_peer, peer); + + return Poco::Net::SocketAddress{family, peer}; + } +} + + +String AuthMiddlewareFactory::TokenStorage::getToken(std::string username, std::string password) +{ + std::lock_guard lock(token_mutex); + + cleanupExpiredTokens(); + + auto token = toString(UUIDHelpers::generateV4()); + auto expiration_time = std::chrono::steady_clock::now() + std::chrono::seconds(config.getInt("default_session_timeout", 60)); + auto exp_iter = token_expiration_list.insert({expiration_time, token}); + token_expiration_list_index[token] = exp_iter; + token_to_credentials[token] = {username, password}; + + return token; +} + +std::optional> AuthMiddlewareFactory::TokenStorage::getCredentials(std::string token) +{ + std::lock_guard lock(token_mutex); + + cleanupExpiredTokens(); + + auto it = token_to_credentials.find(token); + if (it != token_to_credentials.end()) + { + auto expiration_time = 
std::chrono::steady_clock::now() + std::chrono::seconds(config.getInt("default_session_timeout", 60)); + auto new_iter = token_expiration_list.insert({expiration_time, token}); + token_expiration_list.erase(token_expiration_list_index[token]); + token_expiration_list_index[token] = new_iter; + return it->second; + } + return std::nullopt; +} + +void AuthMiddlewareFactory::TokenStorage::cleanupExpiredTokens() +{ + auto now = std::chrono::steady_clock::now(); + for (auto it = token_expiration_list.begin(); it != token_expiration_list.end() && it->first <= now;) + { + token_to_credentials.erase(it->second); + token_expiration_list_index.erase(it->second); + it = token_expiration_list.erase(it); + } +} + +arrow::Status AuthMiddlewareFactory::StartCall( + const arrow::flight::CallInfo & /*info*/, + const arrow::flight::ServerCallContext & context, + std::shared_ptr * middleware) +{ + const auto & headers = context.incoming_headers(); + + std::string username("default"); + std::string password; + std::string token; + auto session = std::make_shared(server.context(), ClientInfo::Interface::ARROW_FLIGHT); + + bool auth = false; + + try + { + if (auto credentials = getCredentialsFromBasicHeader(headers)) + { + auth = true; + std::tie(username, password) = *credentials; + } + else if (auto token_opt = getTokenFromBearerHeader(headers); token_opt && *token_opt != "None") + { + token = *token_opt; + credentials = token_storage.getCredentials(token); + if (!credentials) + return arrow::flight::MakeFlightError(arrow::flight::FlightStatusCode::Unauthenticated, "Session expired or not authenticated."); + + std::tie(username, password) = *credentials; + } + session->authenticate(username, password, getClientAddress(context)); + } + catch (DB::Exception & e) + { + return arrow::flight::MakeFlightError(arrow::flight::FlightStatusCode::Unauthenticated, e.what()); + } + + try + { + std::string session_id; + auto session_it = headers.find("x-clickhouse-session-id"); + if (session_it != 
headers.end()) + session_id = std::string(session_it->second); + + std::string session_check; + session_it = headers.find("x-clickhouse-session-check"); + if (session_it != headers.end()) + session_check = std::string(session_it->second); + + std::string session_timeout_str; + session_it = headers.find("x-clickhouse-session-timeout"); + if (session_it != headers.end()) + session_timeout_str = std::string(session_it->second); + + unsigned session_timeout = 0; + if (!session_timeout_str.empty()) + { + ReadBufferFromString buf(session_timeout_str); + if (!tryReadIntText(session_timeout, buf) || !buf.eof()) + return arrow::Status::Invalid("Invalid session timeout: " + session_timeout_str); + } + + std::string session_close; + session_it = headers.find("x-clickhouse-session-close"); + if (session_it != headers.end()) + session_close = std::string(session_it->second); + + if (session_id.empty()) + session->makeSessionContext(); + else + session->makeSessionContext(session_id, parseSessionTimeout(server.context()->getConfigRef(), session_timeout), session_check == "1"); + + if (auth) + token = token_storage.getToken(username, password); + + *middleware = std::make_unique(session, token, username, session_id, session_close == "1" && server.config().getBool("enable_arrow_close_session", true)); + } + catch (DB::Exception & e) + { + return arrow::Status::Invalid(e.what()); + } + + return arrow::Status::OK(); +} + +} diff --git a/src/Server/ArrowFlight/AuthMiddleware.h b/src/Server/ArrowFlight/AuthMiddleware.h new file mode 100644 index 000000000000..b112d6008773 --- /dev/null +++ b/src/Server/ArrowFlight/AuthMiddleware.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include + +#include + +#include +#include + +namespace DB +{ + +inline const std::string AUTHORIZATION_HEADER = "authorization"; +inline const std::string AUTHORIZATION_MIDDLEWARE_NAME = "authorization_middleware"; + +class AuthMiddleware : public arrow::flight::ServerMiddleware +{ +public: + explicit 
AuthMiddleware(std::shared_ptr session_, const std::string & token_, const std::string & username_, + const std::string & session_id_ = "", bool session_close_ = false) + : session(session_) + , token(token_) + , username(username_) + , session_id(session_id_) + , session_close(session_close_) + { + } + + static AuthMiddleware & get(const arrow::flight::ServerCallContext & context) + { + return *static_cast(context.GetMiddleware(AUTHORIZATION_MIDDLEWARE_NAME)); + } + + const std::string & getUsername() const { return username; } + const std::shared_ptr & getSession() const { return session; } + + void SendingHeaders(arrow::flight::AddCallHeaders * outgoing_headers) override; + void CallCompleted(const arrow::Status & /*status*/) override; + + std::string name() const override { return AUTHORIZATION_MIDDLEWARE_NAME; } + +private: + std::shared_ptr session; + std::string token; + std::string username; + const std::string session_id; + const bool session_close; +}; + +class AuthMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory +{ + /// TokenStorage keeps track of issued tokens, check for expiration and expires them on any access, + /// updates expiration time of not expired token on request for credentials (getCredentials) + class TokenStorage + { + public: + explicit TokenStorage(const Poco::Util::AbstractConfiguration & config_) : config(config_) {} + + /// Generates unique token for given credentials and saves it in storage. + String getToken(std::string username, std::string password); + + /// Returns credential associated with specific token and updates expiration time for this token. + /// If the token isn't found (never existed or expired) - returns empty optional. 
+ std::optional> getCredentials(std::string token); + + private: + void cleanupExpiredTokens() TSA_REQUIRES(token_mutex); + + using TokenExpirationList = std::multimap; + + std::mutex token_mutex; + TokenExpirationList token_expiration_list TSA_GUARDED_BY(token_mutex); + std::unordered_map token_expiration_list_index TSA_GUARDED_BY(token_mutex); + std::unordered_map> token_to_credentials TSA_GUARDED_BY(token_mutex); + + const Poco::Util::AbstractConfiguration & config; + }; + +public: + explicit AuthMiddlewareFactory(IServer & server_) + : server(server_) + , token_storage(server_.config()) + {} + + arrow::Status StartCall( + const arrow::flight::CallInfo & /*info*/, + const arrow::flight::ServerCallContext & context, + std::shared_ptr * middleware) override; + + private: + IServer & server; + TokenStorage token_storage; +}; + +} diff --git a/src/Server/ArrowFlight/CallsData.cpp b/src/Server/ArrowFlight/CallsData.cpp new file mode 100644 index 000000000000..2b8f6d9da1dd --- /dev/null +++ b/src/Server/ArrowFlight/CallsData.cpp @@ -0,0 +1,567 @@ +#include + +#if USE_ARROWFLIGHT + +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int QUERY_WAS_CANCELLED; +} + +namespace ArrowFlight +{ + +Timestamp CallsData::now() +{ + return std::chrono::system_clock::now(); +} + +CallsData::CallsData(std::optional tickets_lifetime_, std::optional poll_descriptors_lifetime_, LoggerPtr log_) + : tickets_lifetime(tickets_lifetime_) + , poll_descriptors_lifetime(poll_descriptors_lifetime_) + , log(log_) +{ +} + +std::shared_ptr CallsData::createTicket(std::shared_ptr arrow_table) +{ + String ticket = generateTicketName(); + LOG_DEBUG(log, "Creating ticket {}", ticket); + auto expiration_time = calculateTicketExpirationTime(now()); + auto info = std::make_shared(); + info->ticket = ticket; + info->expiration_time = expiration_time; + info->arrow_table = arrow_table; + std::lock_guard lock{mutex}; + bool inserted 
= tickets.try_emplace(ticket, info).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) + chassert(inserted); /// Flight tickets are unique. + if (expiration_time) + { + inserted = tickets_by_expiration_time.emplace(*expiration_time, ticket).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) + chassert(inserted); /// Flight tickets are unique. + updateNextExpirationTime(); + } + return info; +} + +arrow::Result> CallsData::getTicketInfo(const String & ticket) const +{ + std::lock_guard lock{mutex}; + auto it = tickets.find(ticket); + if (it == tickets.end()) + return arrow::Status::KeyError("Ticket ", quoteString(ticket), " not found"); + return it->second; +} + +std::optional CallsData::getTicketExpirationTime(const String & ticket) const +{ + if (!tickets_lifetime) + return std::nullopt; + std::lock_guard lock{mutex}; + auto it = tickets.find(ticket); + if (it == tickets.end()) + return ALREADY_EXPIRED; + return it->second->expiration_time; +} + +void CallsData::cancelTicket(const String & ticket) +{ + std::lock_guard lock{mutex}; + auto it = tickets.find(ticket); + if (it == tickets.end()) + return; /// The ticket has been already cancelled. 
+ LOG_DEBUG(log, "Cancelling ticket {}", ticket); + auto info = it->second; + tickets.erase(it); + if (info->expiration_time) + { + tickets_by_expiration_time.erase(std::make_pair(*info->expiration_time, ticket)); + updateNextExpirationTime(); + } +} + +void CallsData::setFlightDescriptorMapLocked(const String & flight_descriptor, const String & query_id) +{ + flight_descriptor_to_query_id[flight_descriptor] = query_id; + query_id_to_flight_descriptors[query_id].insert(flight_descriptor); +} + +void CallsData::eraseFlightDescriptorMapByQueryIdLocked(const String & query_id) +{ + auto it = query_id_to_flight_descriptors.find(query_id); + if (it == query_id_to_flight_descriptors.end()) + return; + for (const auto & flight_descriptor : it->second) + flight_descriptor_to_query_id.erase(flight_descriptor); + query_id_to_flight_descriptors.erase(it); +} + +void CallsData::eraseFlightDescriptorMapByQueryId(const String & query_id) +{ + std::lock_guard lock{mutex}; + eraseFlightDescriptorMapByQueryIdLocked(query_id); +} + +void CallsData::eraseFlightDescriptorMapByDescriptorLocked(const String & flight_descriptor) +{ + if (!flight_descriptor_to_query_id.contains(flight_descriptor)) + return; + eraseFlightDescriptorMapByQueryIdLocked(flight_descriptor_to_query_id[flight_descriptor]); +} + +void CallsData::eraseFlightDescriptorMapByDescriptor(const String & flight_descriptor) +{ + std::lock_guard lock{mutex}; + eraseFlightDescriptorMapByDescriptorLocked(flight_descriptor); +} + +void CallsData::eraseFlightDescriptorMapEntryLocked(const String & flight_descriptor) +{ + auto it_fd = flight_descriptor_to_query_id.find(flight_descriptor); + if (it_fd == flight_descriptor_to_query_id.end()) + return; + + String query_id = it_fd->second; + flight_descriptor_to_query_id.erase(it_fd); + + auto it_q = query_id_to_flight_descriptors.find(query_id); + if (it_q == query_id_to_flight_descriptors.end()) + return; + + it_q->second.erase(flight_descriptor); + if (it_q->second.empty()) + 
query_id_to_flight_descriptors.erase(it_q); +} + +void CallsData::eraseFlightDescriptorMapEntry(const String & flight_descriptor) +{ + std::lock_guard lock{mutex}; + eraseFlightDescriptorMapEntryLocked(flight_descriptor); +} + +std::shared_ptr +CallsData::createPollDescriptorImpl(std::unique_ptr poll_session, std::shared_ptr previous_info, std::optional flight_descriptor, std::optional query_id) +{ + String poll_descriptor; + std::lock_guard lock{mutex}; + if (previous_info) + { + if (!previous_info->evaluated) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Adding a poll descriptor while the previous poll descriptor is not evaluated"); + if (!previous_info->next_poll_descriptor) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Adding a poll descriptor while the previous poll descriptor is final"); + poll_descriptor = *previous_info->next_poll_descriptor; + query_id = getQueryIdByFlightDescriptorLocked(previous_info->poll_descriptor); + if (!query_id) + throw Exception(ErrorCodes::QUERY_WAS_CANCELLED, + "Cannot create continuation poll descriptor: previous poll descriptor {} was expired or cancelled", + previous_info->poll_descriptor); + } + else + { + poll_descriptor = generatePollDescriptorName(); + } + LOG_DEBUG(log, "Creating poll descriptor {}", poll_descriptor); + auto current_time = now(); + auto expiration_time = calculatePollDescriptorExpirationTime(current_time); + auto info = std::make_shared(); + info->poll_descriptor = poll_descriptor; + info->expiration_time = expiration_time; + info->schema = poll_session->getSchema(); + info->previous_info = previous_info; + info->query_id = *query_id; + if (previous_info) + info->original_flight_descriptor = previous_info->original_flight_descriptor; + else + info->original_flight_descriptor = *flight_descriptor; + bool inserted = poll_descriptors.try_emplace(poll_descriptor, info).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) + chassert(inserted); /// Poll descriptors are unique. 
+ inserted = poll_sessions.try_emplace(poll_descriptor, std::move(poll_session)).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) + chassert(inserted); /// Poll descriptors are unique. + if (expiration_time) + { + inserted = poll_descriptors_by_expiration_time.emplace(*expiration_time, poll_descriptor).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) + chassert(inserted); /// Poll descriptors are unique. + updateNextExpirationTime(); + } + if (query_id) + setFlightDescriptorMapLocked(poll_descriptor, *query_id); + return info; +} + +std::shared_ptr +CallsData::createPollDescriptor(std::unique_ptr poll_session, std::shared_ptr previous_info) +{ + return createPollDescriptorImpl(std::move(poll_session), previous_info); +} + +std::shared_ptr +CallsData::createPollDescriptor(std::unique_ptr poll_session, const arrow::flight::FlightDescriptor & flight_descriptor, const String & query_id) +{ + return createPollDescriptorImpl(std::move(poll_session), nullptr, flight_descriptor, query_id); +} + +arrow::Result> CallsData::getPollDescriptorInfo(const String & poll_descriptor) const +{ + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it == poll_descriptors.end()) + return arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); + return it->second; +} + +std::optional CallsData::getQueryIdByFlightDescriptorLocked(const String & flight_descriptor) const +{ + auto it = flight_descriptor_to_query_id.find(flight_descriptor); + if (it == flight_descriptor_to_query_id.end()) + return std::nullopt; + return it->second; +} + +std::optional CallsData::getQueryIdByFlightDescriptor(const String & flight_descriptor) const +{ + std::lock_guard lock{mutex}; + return getQueryIdByFlightDescriptorLocked(flight_descriptor); +} + +PollDescriptorWithExpirationTime CallsData::getPollDescriptorWithExpirationTime(const String & poll_descriptor) const +{ + if (!poll_descriptors_lifetime) + return 
PollDescriptorWithExpirationTime{.poll_descriptor = poll_descriptor, .expiration_time = std::nullopt}; + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it == poll_descriptors.end()) + return PollDescriptorWithExpirationTime{.poll_descriptor = poll_descriptor, .expiration_time = ALREADY_EXPIRED}; + return *it->second; +} + +arrow::Status CallsData::extendPollDescriptorExpirationTime(const String & poll_descriptor) +{ + if (!poll_descriptors_lifetime) + return arrow::Status::OK(); + auto current_time = now(); + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it == poll_descriptors.end()) + return arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); + auto info = it->second; + auto old_expiration_time = info->expiration_time; + auto new_expiration_time = calculatePollDescriptorExpirationTime(current_time); + auto new_info = std::make_shared(*info); + new_info->expiration_time = new_expiration_time; + it->second = new_info; + poll_descriptors_by_expiration_time.erase(std::make_pair(*old_expiration_time, poll_descriptor)); + poll_descriptors_by_expiration_time.emplace(*new_expiration_time, poll_descriptor); + updateNextExpirationTime(); + return arrow::Status::OK(); +} + +arrow::Result> CallsData::startEvaluation(const String & poll_descriptor) +{ + arrow::Result> res; + std::unique_lock lock{mutex}; + evaluation_ended.wait(lock, [&]() TSA_REQUIRES(mutex) + { + auto it = poll_descriptors.find(poll_descriptor); + if (it == poll_descriptors.end()) + { + res = arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); + return true; + } + auto info = it->second; + if (info->evaluated) + { + res = std::unique_ptr{nullptr}; + return true; + } + if (!info->evaluating) + { + auto it2 = poll_sessions.find(poll_descriptor); + if (it2 == poll_sessions.end()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Session is not attached to 
non-evaluated poll descriptor {}", poll_descriptor); + res = std::move(it2->second); + poll_sessions.erase(it2); + auto new_info = std::make_shared(*info); + new_info->evaluating = true; + it->second = new_info; + if (info->expiration_time) + { + poll_descriptors_by_expiration_time.erase(std::make_pair(*info->expiration_time, poll_descriptor)); + updateNextExpirationTime(); + } + return true; + } + return false; /// The poll descriptor is being evaluated in another thread, we need to wait. + }); + return res; +} + +void CallsData::endEvaluation(const String & poll_descriptor, const std::optional & ticket, UInt64 rows, UInt64 bytes, bool last) +{ + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it == poll_descriptors.end()) + { + /// The poll descriptor expired during the query execution. + evaluation_ended.notify_all(); + return; + } + + auto info = it->second; + if (info->evaluated) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Poll descriptor can't be evaluated twice"); + + auto new_info = std::make_shared(*info); + new_info->evaluating = false; + new_info->evaluated = true; + new_info->status = arrow::Status::OK(); + new_info->ticket = ticket; + new_info->rows = rows; + new_info->bytes = bytes; + if (!last) + new_info->next_poll_descriptor = generatePollDescriptorName(); + auto new_expiration_time = calculatePollDescriptorExpirationTime(now()); + new_info->expiration_time = new_expiration_time; + it->second = new_info; + if (new_expiration_time) + { + poll_descriptors_by_expiration_time.emplace(*new_expiration_time, poll_descriptor); + updateNextExpirationTime(); + } + info = new_info; + evaluation_ended.notify_all(); +} + +void CallsData::endEvaluationWithError(const String & poll_descriptor, const arrow::Status & error_status) +{ + chassert(!error_status.ok()); + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it != poll_descriptors.end()) + { + auto info = it->second; + if 
(!info->evaluated) + { + auto new_info = std::make_shared(*info); + new_info->evaluating = false; + new_info->evaluated = true; + new_info->status = error_status; + auto new_expiration_time = calculatePollDescriptorExpirationTime(now()); + new_info->expiration_time = new_expiration_time; + it->second = new_info; + if (new_expiration_time) + { + poll_descriptors_by_expiration_time.emplace(*new_expiration_time, poll_descriptor); + updateNextExpirationTime(); + } + info = new_info; + evaluation_ended.notify_all(); + } + } + else + { + evaluation_ended.notify_all(); + } +} + +void CallsData::cancelPollDescriptor(const String & poll_descriptor) +{ + std::unique_ptr poll_session_to_cancel; + { + std::lock_guard lock{mutex}; + auto it = poll_descriptors.find(poll_descriptor); + if (it != poll_descriptors.end()) + { + LOG_DEBUG(log, "Cancelling poll descriptor {}", poll_descriptor); + auto info = it->second; + poll_descriptors.erase(it); + if (info->expiration_time) + { + poll_descriptors_by_expiration_time.erase(std::make_pair(*info->expiration_time, poll_descriptor)); + updateNextExpirationTime(); + } + } + auto it2 = poll_sessions.find(poll_descriptor); + if (it2 != poll_sessions.end()) + { + poll_session_to_cancel = std::move(it2->second); + poll_sessions.erase(it2); + } + eraseFlightDescriptorMapEntryLocked(poll_descriptor); + evaluation_ended.notify_all(); + } + + if (poll_session_to_cancel) + { + try + { + poll_session_to_cancel->onCancelOrConnectionLoss(); + } + catch (...) 
+ { + tryLogCurrentException(log, "cancelPollDescriptor: block_io.onCancelOrConnectionLoss failed"); + } + } +} + +void CallsData::cancelExpired() +{ + std::vector> poll_sessions_to_cancel; + auto current_time = now(); + { + std::lock_guard lock{mutex}; + while (!tickets_by_expiration_time.empty()) + { + auto it = tickets_by_expiration_time.begin(); + if (current_time <= it->first) + break; + LOG_DEBUG(log, "Cancelling expired ticket {}", it->second); + tickets.erase(it->second); + tickets_by_expiration_time.erase(it); + } + + for (auto it = poll_descriptors_by_expiration_time.begin(); it != poll_descriptors_by_expiration_time.end();) + { + if (current_time <= it->first) + break; + + auto pd_it = poll_descriptors.find(it->second); + chassert(pd_it != poll_descriptors.end()); + if (pd_it == poll_descriptors.end()) + { + LOG_WARNING(log, "Poll descriptor {} found in expiration index but not in poll_descriptors; removing stale expiration entry", it->second); + it = poll_descriptors_by_expiration_time.erase(it); + continue; + } + + chassert(!pd_it->second->evaluating); + if (pd_it->second->evaluating) + { + ++it; + continue; + } + + LOG_DEBUG(log, "Cancelling expired poll descriptor {}", it->second); + poll_descriptors.erase(pd_it); + auto it2 = poll_sessions.find(it->second); + if (it2 != poll_sessions.end()) + { + poll_sessions_to_cancel.emplace_back(std::move(it2->second)); + poll_sessions.erase(it2); + } + eraseFlightDescriptorMapEntryLocked(it->second); + it = poll_descriptors_by_expiration_time.erase(it); + } + updateNextExpirationTime(); + } + + for (auto & session : poll_sessions_to_cancel) + { + if (!session) + continue; + + try + { + session->onCancelOrConnectionLoss(); + } + catch (...) 
+ { + tryLogCurrentException(log, "cancelExpired: block_io.onCancelOrConnectionLoss failed"); + } + } +} + +std::vector CallsData::collectPollDescriptorsForQueryId(const String & query_id) const +{ + std::lock_guard lock{mutex}; + auto it = query_id_to_flight_descriptors.find(query_id); + if (it == query_id_to_flight_descriptors.end()) + return {}; + return {it->second.begin(), it->second.end()}; +} + +/// TSA_NO_THREAD_SAFETY_ANALYSIS because TSA doesn't support std::unique_lock used with condition_variable. +void CallsData::waitNextExpirationTime() const TSA_NO_THREAD_SAFETY_ANALYSIS +{ + auto current_time = now(); + std::unique_lock lock{mutex}; + auto expiration_time = next_expiration_time; + /// TSA_NO_THREAD_SAFETY_ANALYSIS because the mutex is held by the enclosing unique_lock, but TSA can't see that inside a lambda. + auto is_ready = [&]() TSA_NO_THREAD_SAFETY_ANALYSIS + { + if (stop_waiting_next_expiration_time) + return true; + if (next_expiration_time != expiration_time) + return true; /// We need to restart waiting if the next expiration time has changed. 
+ current_time = now(); + return (expiration_time && (current_time > *expiration_time)); + }; + if (expiration_time) + { + if (current_time < *expiration_time) + next_expiration_time_updated.wait_for(lock, *expiration_time - current_time, is_ready); + } + else + { + next_expiration_time_updated.wait(lock, is_ready); + } +} + +void CallsData::stopWaitingNextExpirationTime() +{ + std::lock_guard lock{mutex}; + stop_waiting_next_expiration_time = true; + next_expiration_time_updated.notify_all(); +} + +String CallsData::generateTicketName() +{ + return TICKET_PREFIX + toString(UUIDHelpers::generateV4()); +} + +String CallsData::generatePollDescriptorName() +{ + return POLL_DESCRIPTOR_PREFIX + toString(UUIDHelpers::generateV4()); +} + +std::optional CallsData::calculateTicketExpirationTime(Timestamp current_time) const +{ + if (!tickets_lifetime) + return std::nullopt; + return current_time + *tickets_lifetime; +} + +std::optional CallsData::calculatePollDescriptorExpirationTime(Timestamp current_time) const +{ + if (!poll_descriptors_lifetime) + return std::nullopt; + return current_time + *poll_descriptors_lifetime; +} + +void CallsData::updateNextExpirationTime() +{ + auto expiration_time = next_expiration_time; + next_expiration_time.reset(); + if (!tickets_by_expiration_time.empty()) + next_expiration_time = tickets_by_expiration_time.begin()->first; + if (!poll_descriptors_by_expiration_time.empty()) + { + auto other_expiration_time = poll_descriptors_by_expiration_time.begin()->first; + next_expiration_time = next_expiration_time ? 
std::min(*next_expiration_time, other_expiration_time) : other_expiration_time; + } + if (next_expiration_time != expiration_time) + next_expiration_time_updated.notify_all(); +} + +} +} + +#endif diff --git a/src/Server/ArrowFlight/CallsData.h b/src/Server/ArrowFlight/CallsData.h new file mode 100644 index 000000000000..28609ccb4933 --- /dev/null +++ b/src/Server/ArrowFlight/CallsData.h @@ -0,0 +1,222 @@ +#pragma once + +#include "config.h" + +#if USE_ARROWFLIGHT + +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB::ArrowFlight +{ + +using Timestamp = std::chrono::system_clock::time_point; +using Duration = std::chrono::system_clock::duration; + +/// We use the ALREADY_EXPIRED timestamp (January 1, 1970) as the expiration time of a ticket or a poll descriptor +/// which is already expired. +inline const Timestamp ALREADY_EXPIRED = Timestamp{Duration{0}}; + +/// We generate tickets with this prefix. +/// Method DoGet() accepts a ticket which is either 1) a ticket with this prefix; or 2) a SQL query. +/// A valid SQL query can't start with this prefix so method DoGet() can distinguish those cases. +inline const String TICKET_PREFIX = "~TICKET-"; + +inline bool hasTicketPrefix(const String & ticket) +{ + return ticket.starts_with(TICKET_PREFIX); +} + +/// We generate poll descriptors with this prefix. +/// Methods PollFlightInfo() or GetSchema() accept a flight descriptor which is either +/// 1) a normal flight descriptor (a table name or a SQL query); or 2) a poll descriptor with this prefix. +/// A valid SQL query can't start with this prefix so methods PollFlightInfo() and GetSchema() can distinguish those cases. +inline const String POLL_DESCRIPTOR_PREFIX = "~POLL-"; + +inline bool hasPollDescriptorPrefix(const String & poll_descriptor) +{ + return poll_descriptor.starts_with(POLL_DESCRIPTOR_PREFIX); +} + +/// A ticket name with its expiration time. 
+struct TicketWithExpirationTime +{ + String ticket; + /// When the ticket expires. + /// std::nullopt means that the ticket expires after using it in DoGet(). + /// Can be equal to ALREADY_EXPIRED. + std::optional expiration_time; +}; + +/// A poll descriptor's name with its expiration time. +struct PollDescriptorWithExpirationTime +{ + String poll_descriptor; + /// When the poll descriptor expires. + /// std::nullopt means that the poll descriptor expires after using it in PollFlightInfo(). + /// Can be equal to ALREADY_EXPIRED. + std::optional expiration_time; +}; + +struct TicketInfo : public TicketWithExpirationTime +{ + std::shared_ptr arrow_table; +}; + +/// Information about a poll descriptor. +/// Objects of type PollDescriptorInfo are stored as a kind of a doubly linked list, +/// the previous object is stored as `previous_info`, and the next object is referenced by `next_poll_descriptor`. +struct PollDescriptorInfo : public PollDescriptorWithExpirationTime +{ + std::shared_ptr schema; + std::shared_ptr previous_info; + bool evaluating = false; + bool evaluated = false; + + arrow::flight::FlightDescriptor original_flight_descriptor; + std::string query_id; + + /// The following fields can be set only if `evaluated == true`: + + /// A success or error status. + std::optional status; + + /// A new ticket. Along with tickets from previous infos (previous_info, previous_info->previous_info, etc.) + /// represents all tickets associated with this poll descriptor. + /// Can be unset if there is no block; or it can specify an already expired ticket. + std::optional ticket; + + /// Adds rows. Along with added rows from previous infos (previous_info, previous_info->previous_info, etc.) + /// represents the total number of rows associated with this poll descriptor. + /// Can be unset if no rows were added. + std::optional rows; + + /// Adds bytes. Along with added bytes from previous infos (previous_info, previous_info->previous_info, etc.) 
+ /// represents the total number of bytes associated with this poll descriptor. + /// Can be unset if no bytes were added. + std::optional bytes; + + /// Next poll descriptor if any. + /// Can be unset if there is no next poll descriptor (no more blocks to pull from the query pipeline). + std::optional next_poll_descriptor; +}; + +/// Keeps information about calls - e.g. blocks extracted from query pipelines, flight tickets, poll descriptors. +class CallsData +{ +public: + CallsData(std::optional tickets_lifetime_, std::optional poll_descriptors_lifetime_, LoggerPtr log_); + + /// Creates a flight ticket which allows downloading a specified block. + std::shared_ptr createTicket(std::shared_ptr arrow_table); + + [[nodiscard]] arrow::Result> getTicketInfo(const String & ticket) const; + + /// Finds the expiration time for a specified ticket. + /// If the ticket is not found it means it was expired and removed from the map. + std::optional getTicketExpirationTime(const String & ticket) const; + + /// Cancels a ticket to free memory. + void cancelTicket(const String & ticket); + + void eraseFlightDescriptorMapByQueryId(const String & query_id); + void eraseFlightDescriptorMapByDescriptor(const String & flight_descriptor); + void eraseFlightDescriptorMapEntry(const String & flight_descriptor); + + /// Creates a poll descriptor. + std::shared_ptr + createPollDescriptor(std::unique_ptr poll_session, std::shared_ptr previous_info); + + std::shared_ptr + createPollDescriptor(std::unique_ptr poll_session, const arrow::flight::FlightDescriptor & flight_descriptor, const String & query_id); + + [[nodiscard]] arrow::Result> getPollDescriptorInfo(const String & poll_descriptor) const; + + /// Finds query id for a specified flight descriptor. + std::optional getQueryIdByFlightDescriptor(const String & flight_descriptor) const; + + /// Finds the expiration time for a specified poll descriptor. 
+ PollDescriptorWithExpirationTime getPollDescriptorWithExpirationTime(const String & poll_descriptor) const; + + /// Extends the expiration time of a poll descriptor. + [[nodiscard]] arrow::Status extendPollDescriptorExpirationTime(const String & poll_descriptor); + + /// Starts evaluation (i.e. getting a block of data) for a specified poll descriptor. + [[nodiscard]] arrow::Result> startEvaluation(const String & poll_descriptor); + + /// Ends evaluation for a specified poll descriptor. + void endEvaluation(const String & poll_descriptor, const std::optional & ticket, UInt64 rows, UInt64 bytes, bool last); + + /// Ends evaluation for a specified poll descriptor with an error. + void endEvaluationWithError(const String & poll_descriptor, const arrow::Status & error_status); + + /// Cancels a poll descriptor to free memory. + void cancelPollDescriptor(const String & poll_descriptor); + + /// Cancels tickets and poll descriptors if the current time is greater than their expiration time. + void cancelExpired(); + + std::vector collectPollDescriptorsForQueryId(const String & query_id) const; + + /// Waits until maybe it's time to cancel expired tickets or poll descriptors. + /// TSA_NO_THREAD_SAFETY_ANALYSIS because TSA doesn't support std::unique_lock used with condition_variable. 
+ void waitNextExpirationTime() const TSA_NO_THREAD_SAFETY_ANALYSIS; + + void stopWaitingNextExpirationTime(); + +private: + static String generateTicketName(); + static String generatePollDescriptorName(); + + std::optional calculateTicketExpirationTime(Timestamp current_time) const; + std::optional calculatePollDescriptorExpirationTime(Timestamp current_time) const; + + void updateNextExpirationTime() TSA_REQUIRES(mutex); + + void setFlightDescriptorMapLocked(const String & flight_descriptor, const String & query_id) TSA_REQUIRES(mutex); + void eraseFlightDescriptorMapByQueryIdLocked(const String & query_id) TSA_REQUIRES(mutex); + void eraseFlightDescriptorMapByDescriptorLocked(const String & flight_descriptor) TSA_REQUIRES(mutex); + void eraseFlightDescriptorMapEntryLocked(const String & flight_descriptor) TSA_REQUIRES(mutex); + std::optional getQueryIdByFlightDescriptorLocked(const String & flight_descriptor) const TSA_REQUIRES(mutex); + + std::shared_ptr + createPollDescriptorImpl(std::unique_ptr poll_session, std::shared_ptr previous_info, std::optional flight_descriptor = std::nullopt, std::optional query_id = std::nullopt); + + static Timestamp now(); + + const std::optional tickets_lifetime; + const std::optional poll_descriptors_lifetime; + const LoggerPtr log; + mutable std::mutex mutex; + std::unordered_map> tickets TSA_GUARDED_BY(mutex); + std::unordered_map> poll_descriptors TSA_GUARDED_BY(mutex); + std::unordered_map> poll_sessions TSA_GUARDED_BY(mutex); + std::condition_variable evaluation_ended; + /// associates flight descriptors with query id + std::unordered_map flight_descriptor_to_query_id TSA_GUARDED_BY(mutex); + std::unordered_map> query_id_to_flight_descriptors TSA_GUARDED_BY(mutex); + /// `tickets_by_expiration_time` and `poll_descriptors_by_expiration_time` are sorted by `expiration_time` so `std::set` is used. 
+ std::set> tickets_by_expiration_time TSA_GUARDED_BY(mutex); + std::set> poll_descriptors_by_expiration_time TSA_GUARDED_BY(mutex); + std::optional next_expiration_time TSA_GUARDED_BY(mutex); + mutable std::condition_variable next_expiration_time_updated; + bool stop_waiting_next_expiration_time TSA_GUARDED_BY(mutex) = false; +}; + +} + +#endif diff --git a/src/Server/ArrowFlight/PollSession.cpp b/src/Server/ArrowFlight/PollSession.cpp new file mode 100644 index 000000000000..50b2d5518f06 --- /dev/null +++ b/src/Server/ArrowFlight/PollSession.cpp @@ -0,0 +1,91 @@ +#include + +#if USE_ARROWFLIGHT + +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int UNKNOWN_EXCEPTION; +} + +namespace Setting +{ + extern const SettingsBool output_format_arrow_unsupported_types_as_binary; +} + +namespace ArrowFlight +{ + +PollSession::PollSession( + ContextPtr query_context_, + ThreadGroupPtr thread_group_, + BlockIO && block_io_, + SchemaModifier schema_modifier, + BlockModifier block_modifier_) + : query_context(query_context_) + , thread_group(thread_group_) + , block_io(std::move(block_io_)) + , block_modifier(block_modifier_) +{ + try + { + executor.emplace(block_io.pipeline); + schema = CHColumnToArrowColumn::calculateArrowSchema( + executor->getHeader().getColumnsWithTypeAndName(), + "Arrow", + nullptr, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}); + + if (schema_modifier) + { + auto result = schema_modifier(schema); + if (!result.ok()) + throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Failed to convert Arrow schema: {} (schema: {})", result.status().ToString(), schema->ToString()); + schema = result.ValueUnsafe(); + } + } + catch (...) + { + try { block_io.onException(); } + catch (...) 
{ tryLogCurrentException("PollSession: block_io.onException() failed during constructor rollback"); } + throw; + } +} + +PollSession::~PollSession() = default; + +ContextPtr PollSession::queryContext() { return query_context; } + +ThreadGroupPtr PollSession::getThreadGroup() const { return thread_group; } + +std::shared_ptr PollSession::getSchema() const { return schema; } + +bool PollSession::getNextBlock(Block & block) +{ + if (!executor->pull(block)) + return false; + if (block_modifier) + block_modifier(query_context, block); + return true; +} + +void PollSession::onFinish() { block_io.onFinish(); } + +void PollSession::onException() { block_io.onException(); } + +void PollSession::onCancelOrConnectionLoss() { block_io.onCancelOrConnectionLoss(); } + +} +} + +#endif diff --git a/src/Server/ArrowFlight/PollSession.h b/src/Server/ArrowFlight/PollSession.h new file mode 100644 index 000000000000..d52762816edd --- /dev/null +++ b/src/Server/ArrowFlight/PollSession.h @@ -0,0 +1,59 @@ +#pragma once + +#include "config.h" + +#if USE_ARROWFLIGHT + +#include + +#include +#include +#include +#include + +#include + +#include + + +namespace DB +{ + +namespace ArrowFlight +{ + +/// Keeps a query context and a pipeline executor for PollFlightInfo. 
+class PollSession +{ +public: + PollSession( + ContextPtr query_context_, + ThreadGroupPtr thread_group_, + BlockIO && block_io_, + SchemaModifier schema_modifier = nullptr, + BlockModifier block_modifier_ = nullptr); + + ~PollSession(); + + ContextPtr queryContext(); + + ThreadGroupPtr getThreadGroup() const; + std::shared_ptr getSchema() const; + bool getNextBlock(Block & block); + void onFinish(); + void onException(); + void onCancelOrConnectionLoss(); + +private: + ContextPtr query_context; + ThreadGroupPtr thread_group; + BlockIO block_io; + std::optional executor; + std::shared_ptr schema; + BlockModifier block_modifier; +}; + +} +} + +#endif diff --git a/src/Server/ArrowFlight/commandSelector.cpp b/src/Server/ArrowFlight/commandSelector.cpp new file mode 100644 index 000000000000..89c3308347b6 --- /dev/null +++ b/src/Server/ArrowFlight/commandSelector.cpp @@ -0,0 +1,705 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace Setting +{ + extern const SettingsBool output_format_arrow_unsupported_types_as_binary; +} + +namespace ArrowFlight +{ + +static arrow::Result> commandGetSqlInfo(const arrow::flight::protocol::sql::CommandGetSqlInfo & command, bool schema_only) +{ + arrow::MemoryPool* pool = arrow::default_memory_pool(); + + auto string_builder = std::make_shared(); + auto boolean_builder = std::make_shared(); + auto int64_builder = std::make_shared(); + auto int32_builder = std::make_shared(); + + // string_list: list not null + auto string_list_type = arrow::list(arrow::utf8()); + auto string_list_builder = std::make_shared(pool, std::make_shared(), string_list_type); + + // int32_to_int32_list_map: map> not null + auto value_type = arrow::list(arrow::int32()); + auto value_builder = std::make_shared(pool, 
std::make_shared(), value_type); + auto int32_to_int32_list_map_type = arrow::map(arrow::int32(), value_type); + auto int32_to_int32_list_map_builder = std::make_shared(pool, std::make_shared(), value_builder, int32_to_int32_list_map_type); + + // dense_union + auto dense_union_type = arrow::dense_union( + { + std::make_shared("string_value", arrow::utf8(), false), + std::make_shared("bool_value", arrow::boolean(), false), + std::make_shared("bigint_value", arrow::int64(), false), + std::make_shared("int32_bitmask", arrow::int32(), false), + std::make_shared("string_list", string_list_type, false), + std::make_shared("int32_to_int32_list_map", int32_to_int32_list_map_type, false) + }); + + auto dense_union_builder = std::make_shared( + pool, + std::vector>{ + string_builder, + boolean_builder, + int64_builder, + int32_builder, + string_list_builder, + int32_to_int32_list_map_builder + }, + dense_union_type); + + using SqlInfo = arrow::flight::protocol::sql::SqlInfo; + + auto info_name_builder = std::make_shared(); + + static const size_t SQL_INFO_STRING = 0; + static const size_t SQL_INFO_BOOLEAN = 1; + static const size_t SQL_INFO_INT64 = 2; + static const size_t SQL_INFO_INT32 = 3; + + auto builder_string_append = [&](auto i, const std::string & v) + { + ARROW_RETURN_NOT_OK(info_name_builder->Append(i)); + ARROW_RETURN_NOT_OK(dense_union_builder->Append(SQL_INFO_STRING)); + return string_builder->Append(v); + }; + + auto builder_boolean_append = [&](auto i, bool v) + { + ARROW_RETURN_NOT_OK(info_name_builder->Append(i)); + ARROW_RETURN_NOT_OK(dense_union_builder->Append(SQL_INFO_BOOLEAN)); + return boolean_builder->Append(v); + }; + + [[maybe_unused]] auto builder_int64_append = [&](auto i, int64_t v) + { + ARROW_RETURN_NOT_OK(info_name_builder->Append(i)); + ARROW_RETURN_NOT_OK(dense_union_builder->Append(SQL_INFO_INT64)); + return int64_builder->Append(v); + }; + + auto builder_int32_append = [&](auto i, int32_t v) + { + 
ARROW_RETURN_NOT_OK(info_name_builder->Append(i)); + ARROW_RETURN_NOT_OK(dense_union_builder->Append(SQL_INFO_INT32)); + return int32_builder->Append(v); + }; + + if (!schema_only) + { + #define SQL_INFO_SELECTOR(INFO_NAME, BUILDER, ARG) \ + { INFO_NAME, [&](){ return BUILDER(INFO_NAME, ARG); } } + + std::unordered_map> selector + { + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_NAME, builder_string_append, "ClickHouse"), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_VERSION, builder_string_append, VERSION_STRING), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, builder_string_append, ARROW_VERSION_STRING), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, builder_boolean_append, false), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_SQL, builder_boolean_append, true), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, builder_boolean_append, false), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, builder_string_append, ""), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, builder_string_append, ""), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, builder_int32_append, arrow::flight::protocol::sql::SQL_SUPPORTED_TRANSACTION_NONE), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_CANCEL, builder_boolean_append, true), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT, builder_int32_append, 0), + SQL_INFO_SELECTOR(SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT, builder_int32_append, 0) + }; + #undef SQL_INFO_SELECTOR + + if (command.info().empty()) + { + for (const auto & [_, builder] : selector) + ARROW_RETURN_NOT_OK(builder()); + } + else + { + for (const auto & info_name : command.info()) + if (auto it = selector.find(static_cast(info_name)); it != selector.end()) + ARROW_RETURN_NOT_OK(it->second()); + } + } + + // Schema for table + std::shared_ptr table_schema = arrow::schema({ + arrow::field("info_name", arrow::uint32()), + arrow::field("value", dense_union_type) + 
}); + + auto info_name = info_name_builder->Finish(); + ARROW_RETURN_NOT_OK(info_name); + + auto value = dense_union_builder->Finish(); + ARROW_RETURN_NOT_OK(value); + + return arrow::Table::Make(table_schema, {info_name.ValueUnsafe(), value.ValueUnsafe()}); +} + +static SQLSet commandGetCatalogs() +{ + return {"SELECT '' AS catalog_name FROM numbers(0)", {}, {}}; +} + +static SQLSet commandGetDbSchemas(const arrow::flight::protocol::sql::CommandGetDbSchemas & command) +{ + std::string where_expression; + if (command.has_db_schema_filter_pattern()) + where_expression = " WHERE database LIKE " + quoteString(command.db_schema_filter_pattern()); + + return {"SELECT NULL::Nullable(String) AS catalog_name, name AS db_schema_name FROM system.databases" + where_expression, {}, {}}; +} + +/// Splits a formatted expression list (e.g. from `system.tables.primary_key`) into +/// individual expressions. Unlike `splitByChar(',', ...)`, this correctly handles commas +/// inside parenthesized function arguments, square-bracketed array subscripts, +/// single-quoted string literals, and backtick-quoted identifiers. 
+static std::vector splitExpressionList(std::string_view s) +{ + std::vector result; + int depth = 0; + bool in_single_quote = false; + bool in_backtick = false; + bool escape_next = false; + size_t expr_start = 0; + + auto flush = [&](size_t end) + { + auto token = s.substr(expr_start, end - expr_start); + size_t first = token.find_first_not_of(' '); + size_t last = token.find_last_not_of(' '); + if (first != std::string_view::npos) + result.emplace_back(token.substr(first, last - first + 1)); + }; + + for (size_t i = 0; i < s.size(); ++i) + { + char c = s[i]; + + if (escape_next) + { + escape_next = false; + continue; + } + + if (c == '\\' && (in_single_quote || in_backtick)) + { + escape_next = true; + continue; + } + + if (c == '\'' && !in_backtick) + { + in_single_quote = !in_single_quote; + continue; + } + + if (c == '`' && !in_single_quote) + { + in_backtick = !in_backtick; + continue; + } + + if (in_single_quote || in_backtick) + continue; + + if (c == '(' || c == '[') + ++depth; + else if (c == ')' || c == ']') + --depth; + else if (c == ',' && depth == 0) + { + flush(i); + expr_start = i + 1; + } + } + + flush(s.size()); + return result; +} + +static SQLSet commandGetPrimaryKeys(const arrow::flight::protocol::sql::CommandGetPrimaryKeys & command) +{ + std::string where_expression = " WHERE" + + (command.has_db_schema() ? 
(" database = " + quoteString(command.db_schema()) + " AND") : "") + + " name = " + quoteString(command.table()); + + auto sql = + "SELECT " + "materialize(NULL::Nullable(String)) AS catalog_name, " + "database AS schema_name, " + "name AS table_name, " + "primary_key AS column_name, " + "materialize(0::Int32) AS key_seq, " + "materialize(NULL::Nullable(String)) AS pk_name " + "FROM system.tables" + + where_expression; + + auto block_modifier = [](ContextPtr, Block & block) + { + size_t num_rows = block.rows(); + if (num_rows == 0) + return; + + const size_t column_name_pos = 3; + const size_t key_seq_pos = 4; + const size_t num_columns = block.columns(); + + auto & pk_column = block.getByPosition(column_name_pos); + auto pk_col = pk_column.column->convertToFullIfNeeded(); + + std::vector new_columns; + for (size_t col = 0; col < num_columns; ++col) + new_columns.push_back(block.getByPosition(col).column->cloneEmpty()); + + for (size_t i = 0; i < num_rows; ++i) + { + auto expressions = splitExpressionList(pk_col->getDataAt(i)); + + Int32 key_seq = 1; + for (const auto & expr : expressions) + { + for (size_t col = 0; col < num_columns; ++col) + { + if (col == column_name_pos) + new_columns[col]->insert(expr); + else if (col == key_seq_pos) + new_columns[col]->insert(key_seq); + else + new_columns[col]->insertFrom(*block.getByPosition(col).column, i); + } + ++key_seq; + } + } + + block.setColumns(std::move(new_columns)); + }; + + return {sql, {}, block_modifier}; +} + +const static std::vector> engine_to_type = +{ + // Log tables + {"Log", "LOG TABLE"}, + {"StripeLog", "LOG TABLE"}, + {"TinyLog", "LOG TABLE"}, + + // Memory tables + {"Buffer", "MEMORY TABLE"}, + {"Memory", "MEMORY TABLE"}, + {"Set", "MEMORY TABLE"}, + + // Views + {"View", "VIEW"}, + {"LiveView", "VIEW"}, + {"MaterializedView", "MATERIALIZED VIEW"}, + {"WindowView", "VIEW"}, + + // Dictionary + {"Dictionary", "DICTIONARY"}, + + // Remote/External tables + {"AzureBlobStorage", "REMOTE TABLE"}, + 
{"AzureQueue", "REMOTE TABLE"}, + {"ArrowFlight", "REMOTE TABLE"}, + {"DeltaLake", "REMOTE TABLE"}, + {"DeltaLakeAzure", "REMOTE TABLE"}, + {"DeltaLakeLocal", "REMOTE TABLE"}, + {"DeltaLakeS3", "REMOTE TABLE"}, + {"Distributed", "REMOTE TABLE"}, + {"GCS", "REMOTE TABLE"}, + {"HDFS", "REMOTE TABLE"}, + {"Hive", "REMOTE TABLE"}, + {"Hudi", "REMOTE TABLE"}, + {"Iceberg", "REMOTE TABLE"}, + {"IcebergAzure", "REMOTE TABLE"}, + {"IcebergHDFS", "REMOTE TABLE"}, + {"IcebergLocal", "REMOTE TABLE"}, + {"IcebergS3", "REMOTE TABLE"}, + {"JDBC", "REMOTE TABLE"}, + {"Kafka", "REMOTE TABLE"}, + {"MaterializedPostgreSQL", "REMOTE TABLE"}, + {"MongoDB", "REMOTE TABLE"}, + {"MySQL", "REMOTE TABLE"}, + {"NATS", "REMOTE TABLE"}, + {"ODBC", "REMOTE TABLE"}, + {"OSS", "REMOTE TABLE"}, + {"PostgreSQL", "REMOTE TABLE"}, + {"RabbitMQ", "REMOTE TABLE"}, + {"Redis", "REMOTE TABLE"}, + {"S3", "REMOTE TABLE"}, + {"S3Queue", "REMOTE TABLE"}, + {"URL", "REMOTE TABLE"}, + {"YTsaurus", "REMOTE TABLE"}, + + // Regular tables (MergeTree family and others) + {"AggregatingMergeTree", "TABLE"}, + {"Alias", "TABLE"}, + {"CoalescingMergeTree", "TABLE"}, + {"CollapsingMergeTree", "TABLE"}, + {"EmbeddedRocksDB", "TABLE"}, + {"Executable", "TABLE"}, + {"ExecutablePool", "TABLE"}, + {"GraphiteMergeTree", "TABLE"}, + {"Join", "TABLE"}, + {"KeeperMap", "TABLE"}, + {"Merge", "TABLE"}, + {"MergeTree", "TABLE"}, + {"ReplacingMergeTree", "TABLE"}, + {"ReplicatedAggregatingMergeTree", "TABLE"}, + {"ReplicatedCoalescingMergeTree", "TABLE"}, + {"ReplicatedCollapsingMergeTree", "TABLE"}, + {"ReplicatedGraphiteMergeTree", "TABLE"}, + {"ReplicatedMergeTree", "TABLE"}, + {"ReplicatedReplacingMergeTree", "TABLE"}, + {"ReplicatedSummingMergeTree", "TABLE"}, + {"ReplicatedVersionedCollapsingMergeTree", "TABLE"}, + {"SummingMergeTree", "TABLE"}, + {"VersionedCollapsingMergeTree", "TABLE"}, + {"COSN", "TABLE"}, + {"SharedAggregatingMergeTree", "TABLE"}, + {"SharedCoalescingMergeTree", "TABLE"}, + {"SharedCollapsingMergeTree", 
"TABLE"}, + {"SharedGraphiteMergeTree", "TABLE"}, + {"SharedJoin", "TABLE"}, + {"SharedMergeTree", "TABLE"}, + {"SharedReplacingMergeTree", "TABLE"}, + {"SharedSet", "TABLE"}, + {"SharedSummingMergeTree", "TABLE"}, + {"SharedVersionedCollapsingMergeTree", "TABLE"}, + + // Special + {"TimeSeries", "TABLE"}, + {"Null", "TABLE"}, + {"Loop", "TABLE"}, + {"SQLite", "TABLE"}, + {"File", "TABLE"}, + {"FileLog", "TABLE"}, + {"GenerateRandom", "TABLE"}, + {"FuzzJSON", "TABLE"}, + {"FuzzQuery", "TABLE"}, +}; + +const static std::string & getTableTypeMap() +{ + const static auto res = []() + { + auto args = std::ranges::fold_left(engine_to_type, std::pair(), + [](auto acc, const auto & val) + { + if (!acc.first.empty()) + { + acc.first += ", "; + acc.second += ", "; + } + acc.first += "'" + val.first + "'"; + acc.second += "'" + val.second + "'"; + return acc; + } + ); + return "[" + args.first + "], [" + args.second + "]"; + }(); + + return res; +} + +static SQLSet commandGetTables(const arrow::flight::protocol::sql::CommandGetTables & command) +{ + std::vector where; + if (command.has_db_schema_filter_pattern()) + where.push_back("db_schema_name LIKE " + quoteString(command.db_schema_filter_pattern())); + if (command.has_table_name_filter_pattern()) + where.push_back("table_name LIKE " + quoteString(command.table_name_filter_pattern())); + if (command.table_types_size()) + { + where.push_back( + "table_type IN [" + + boost::algorithm::join( + command.table_types() + | boost::adaptors::transformed([](const auto & table_type) { return quoteString(table_type); }), + ", ") + + "]" + ); + } + auto where_expression = where.empty() ? 
"" : " WHERE " + boost::algorithm::join(where, " AND "); + + if (!command.include_schema()) + { + return { + "SELECT " + "catalog_name, " + "db_schema_name, " + "table_name, " + "table_type " + "FROM (" + "SELECT " + "NULL::Nullable(String) AS catalog_name, " + "database::Nullable(String) AS db_schema_name, " + "table AS table_name, " + "transform(engine, " + getTableTypeMap() + ", 'UNKNOWN TABLE TYPE') AS table_type " + "FROM system.tables" + ")" + + where_expression, + {}, + {} + }; + } + + auto sql = + "SELECT " + "catalog_name, " + "db_schema_name, " + "table_name, " + "table_type, " + "table_schema " + "FROM (" + "SELECT " + "NULL::Nullable(String) AS catalog_name, " + "left.database::Nullable(String) AS db_schema_name, " + "left.table AS table_name, " + "transform(left.engine, " + getTableTypeMap() + ", 'UNKNOWN TABLE TYPE') AS table_type, " + "ifNull(right.table_schema, CAST([], 'Array(Tuple(String, String))')) AS table_schema " + "FROM system.tables AS left " + "LEFT JOIN " + "(" + "SELECT " + "database, " + "table, " + "arraySort((x, y) -> y, groupArray((name, type)), groupArray(position)) AS table_schema " + "FROM system.columns " + "GROUP BY " + "database, " + "table" + ") AS right ON left.database = right.database AND left.table = right.table" + ")" + + where_expression; + + auto schema_modifier = [](std::shared_ptr table_schema) + { + const auto & table_schema_field = table_schema->field(4); + return table_schema->SetField(4, std::make_shared(table_schema_field->name(), arrow::binary(), table_schema_field->nullable())); + }; + + auto block_modifier = [](ContextPtr query_context, Block & block) + { + const size_t table_schema_pos = 4; + const auto & table_schema_column = block.getByPosition(table_schema_pos); + auto new_column = ColumnString::create(); + auto col = table_schema_column.column->convertToFullIfNeeded(); + const auto & arr = typeid_cast(*col); + const auto & tuple_col = typeid_cast(arr.getData()); + const auto & name_col = 
typeid_cast(tuple_col.getColumn(0)); + const auto & type_col = typeid_cast(tuple_col.getColumn(1)); + for (size_t i = 0; i < col->size(); ++i) + { + ColumnsWithTypeAndName table_columns; + auto start = i ? arr.getOffsets()[i - 1] : 0; + auto end = arr.getOffsets()[i]; + for (size_t j = 0; j < end - start; ++j) + { + const auto name = name_col.getDataAt(start + j); + const auto type = type_col.getDataAt(start + j); + + auto data_type = DataTypeFactory::instance().get(String(type)); + table_columns.emplace_back(nullptr, data_type, String(name)); + } + auto table_schema = CHColumnToArrowColumn::calculateArrowSchema( + table_columns, "Arrow", nullptr, + {.output_string_as_string = true, .output_unsupported_types_as_binary = query_context->getSettingsRef()[Setting::output_format_arrow_unsupported_types_as_binary]}); + auto serialized_res = arrow::ipc::SerializeSchema(*table_schema, arrow::default_memory_pool()); + if (!serialized_res.ok()) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Failed to serialize Arrow schema: {}", serialized_res.status().ToString()); + const auto & serialized_buffer = serialized_res.ValueUnsafe(); + new_column->insertData(reinterpret_cast(serialized_buffer->data()), serialized_buffer->size()); + } + + block.getByPosition(table_schema_pos) = ColumnWithTypeAndName( + std::move(new_column), + std::make_shared(), + table_schema_column.name); + }; + + return {sql, schema_modifier, block_modifier}; +} + +static SQLSet commandGetTableTypes() +{ + return {"SELECT DISTINCT transform(name, " + getTableTypeMap() + ", 'UNKNOWN TABLE TYPE') AS table_type FROM system.table_engines", {}, {}}; +} + +static CommandSelectorResult commandStatementQuery(const arrow::flight::protocol::sql::CommandStatementQuery & command) +{ + if (command.query().empty()) + return arrow::Status::Invalid("CommandStatementQuery: query must not be empty"); + return SQLSet{command.query(), {}, {}}; +} + +static CommandSelectorResult commandStatementUpdate(const 
arrow::flight::protocol::sql::CommandStatementUpdate & command) +{ + if (command.query().empty()) + return arrow::Status::Invalid("CommandStatementUpdate: query must not be empty"); + return SQLSet{command.query(), {}, {}}; +} + +static CommandSelectorResult commandStatementIngest(const arrow::flight::protocol::sql::CommandStatementIngest & command) +{ + using CommandStatementIngest = arrow::flight::protocol::sql::CommandStatementIngest; + + if (command.has_table_definition_options()) + { + const auto & options = command.table_definition_options(); + if (options.if_not_exist() != CommandStatementIngest::TableDefinitionOptions::TABLE_NOT_EXIST_OPTION_FAIL || + options.if_exists() != CommandStatementIngest::TableDefinitionOptions::TABLE_EXISTS_OPTION_APPEND) + { + return arrow::Status::NotImplemented("Only appending to existing tables is supported (TABLE_NOT_EXIST_OPTION_FAIL + TABLE_EXISTS_OPTION_APPEND)"); + } + } + + if (command.has_catalog()) + return arrow::Status::NotImplemented("Catalogs are not supported."); + + if (command.temporary()) + return arrow::Status::NotImplemented("Implicit temporary tables are not supported."); + + std::string schema_string; + if (command.has_schema()) + { + if (!isValidIdentifier(command.schema())) + return arrow::Status::Invalid("Invalid schema name: ", command.schema()); + schema_string = backQuoteIfNeed(command.schema()) + "."; + } + + if (!isValidIdentifier(command.table())) + return arrow::Status::Invalid("Invalid table name: ", command.table()); + + return SQLSet{"INSERT INTO " + schema_string + backQuoteIfNeed(command.table()) + " FORMAT Arrow", {}, {}}; +} + +static std::optional commandSelectorImpl(const google::protobuf::Any & any_msg, bool schema_only) +{ + if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandGetSqlInfo command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandGetSqlInfo failed."); + return commandGetSqlInfo(command, schema_only); + } 
+ else if (any_msg.Is()) + { + return arrow::Status::NotImplemented("sql::CommandGetCrossReference is not supported"); + } + else if (any_msg.Is()) + { + return commandGetCatalogs(); + } + else if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandGetDbSchemas command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandGetDbSchemas failed."); + return commandGetDbSchemas(command); + } + else if (any_msg.Is()) + { + return arrow::Status::NotImplemented("sql::CommandGetExportedKeys is not supported"); + } + else if (any_msg.Is()) + { + return arrow::Status::NotImplemented("sql::CommandGetImportedKeys is not supported"); + } + else if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandGetPrimaryKeys command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandGetPrimaryKeys failed."); + return commandGetPrimaryKeys(command); + } + else if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandGetTables command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandGetTables failed."); + return commandGetTables(command); + } + else if (any_msg.Is()) + { + return commandGetTableTypes(); + } + else if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandStatementQuery command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandStatementQuery failed."); + return commandStatementQuery(command); + } + else if (any_msg.Is()) + { + arrow::flight::protocol::sql::CommandStatementUpdate command; + if (!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandStatementUpdate failed."); + return commandStatementUpdate(command); + } + else if (any_msg.Is()) + { + using CommandStatementIngest = arrow::flight::protocol::sql::CommandStatementIngest; + CommandStatementIngest command; + if 
(!any_msg.UnpackTo(&command)) + return arrow::Status::SerializationError("Deserialization of sql::CommandStatementIngest failed."); + return commandStatementIngest(command); + } + else + { + if (isArrowFlightSql(any_msg)) + return arrow::Status::NotImplemented("Command is not implemented: ", any_msg.ShortDebugString()); + } + + return std::nullopt; +} + +CommandSelectorResult commandSelector(const std::string & cmd, bool schema_only) +{ + if (cmd.empty()) + return arrow::Status::Invalid("Empty command."); + if (cmd.size() > static_cast(std::numeric_limits::max())) + return arrow::Status::Invalid("Command payload is too large."); + if (google::protobuf::Any any_msg; any_msg.ParseFromArray(cmd.data(), static_cast(cmd.size()))) + if (auto result = commandSelectorImpl(any_msg, schema_only)) + return *result; + return SQLSet{cmd, {}, {}}; +} + +} + +} diff --git a/src/Server/ArrowFlight/commandSelector.h b/src/Server/ArrowFlight/commandSelector.h new file mode 100644 index 000000000000..d376dfe7d252 --- /dev/null +++ b/src/Server/ArrowFlight/commandSelector.h @@ -0,0 +1,84 @@ +#pragma once + +#include + +#include +#include +#include + + +namespace DB +{ + +namespace ArrowFlight +{ + +using SchemaModifier = std::function>(std::shared_ptr)>; +using BlockModifier = std::function; + +struct SQLSet +{ + std::string sql; + SchemaModifier schema_modifier; + BlockModifier block_modifier; +}; + +struct CommandSelectorResult : private std::variant>> +{ + // NOLINTNEXTLINE(google-explicit-constructor) + CommandSelectorResult(const SQLSet & sql_set) : std::variant>>(sql_set) {} + // NOLINTNEXTLINE(google-explicit-constructor) + CommandSelectorResult(const arrow::Result> & table) : std::variant>>(table) {} + // NOLINTNEXTLINE(google-explicit-constructor) + CommandSelectorResult(const arrow::Status & status) : std::variant>>(status) {} + // NOLINTNEXTLINE(google-explicit-constructor) + CommandSelectorResult(std::shared_ptr table) : std::variant>>(table) {} + + SQLSet * getSQLSet() + 
{ + return std::get_if(this); + } + + arrow::Result> * getTable() + { + return std::get_if>>(this); + } +}; + +inline bool isArrowFlightSql(const google::protobuf::Any & any_msg) +{ + const auto & type_url = any_msg.type_url(); + const auto slash_pos = type_url.find_last_of('/'); + const auto type_name = (slash_pos == std::string::npos) ? type_url : type_url.substr(slash_pos + 1); + + return type_name.starts_with("arrow.flight.protocol.sql."); +} + +inline bool cmdIsArrowFlightSql(const std::string & cmd) +{ + if (cmd.size() > static_cast(std::numeric_limits::max())) + return false; + google::protobuf::Any any_msg; + if (!any_msg.ParseFromArray(cmd.data(), static_cast(cmd.size()))) + return false; + return isArrowFlightSql(any_msg); +} + +inline bool flightDescriptorIsArrowFlightSqlCommand(const arrow::flight::FlightDescriptor & descriptor) +{ + if (descriptor.type != arrow::flight::FlightDescriptor::CMD) + return false; + return cmdIsArrowFlightSql(descriptor.cmd); +} + +/// commandSelector accepts arrow flight sql command buffer and produces either resulting arrow::Table +/// (and if schema_only == true then table can be empty - only schema is requested) or set of sql query - which will be executed, +/// and, if resulting table requires modification, possible schema_modifier and block_modifier - they should consistently +/// manipulate schema and blocks to produce compatible results. In case command is invalid, an error should be returned +/// in arrow::Result> variant - in practice, return arrow::Status::<...> - +/// CommandSelectorResult has an implicit constructor for arrow::Status. 
+CommandSelectorResult commandSelector(const std::string & cmd, bool schema_only = false); + +} + +} diff --git a/src/Server/grpc_protos/CMakeLists.txt b/src/Server/grpc_protos/CMakeLists.txt index 902c0b8b1ff4..4eeceecaf3ac 100644 --- a/src/Server/grpc_protos/CMakeLists.txt +++ b/src/Server/grpc_protos/CMakeLists.txt @@ -1,4 +1,5 @@ -PROTOBUF_GENERATE_GRPC_CPP(clickhouse_grpc_proto_sources clickhouse_grpc_proto_headers clickhouse_grpc.proto) +file(GLOB PROTO_FILES CONFIGURE_DEPENDS "*.proto") +PROTOBUF_GENERATE_GRPC_CPP(clickhouse_grpc_proto_sources clickhouse_grpc_proto_headers ${PROTO_FILES}) add_library(clickhouse_grpc_protos ${clickhouse_grpc_proto_headers} ${clickhouse_grpc_proto_sources}) target_include_directories(clickhouse_grpc_protos SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/tests/integration/test_arrowflight_interface/flight_sql_client.py b/tests/integration/test_arrowflight_interface/flight_sql_client.py new file mode 100644 index 000000000000..330530c715f0 --- /dev/null +++ b/tests/integration/test_arrowflight_interface/flight_sql_client.py @@ -0,0 +1,486 @@ +""" +Minimal Flight SQL client using pyarrow.flight and protobuf. + +Replaces the `flightsql-dbapi` package (which conflicts with pyiceberg's +SQLAlchemy>=2 requirement) with just the subset the integration tests need: + - FlightSQLClient (execute, execute_update, do_get, metadata commands) + - flight_descriptor helper + - Flight SQL protobuf message classes +""" + +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple + +import pyarrow as pa +from google.protobuf import any_pb2 +from pyarrow import flight + +# --------------------------------------------------------------------------- +# Protobuf definitions generated from the Arrow Flight SQL .proto +# +# The serialized FileDescriptorProto below was built programmatically from +# the Arrow Flight SQL protocol spec. 
It defines: +# CommandGetSqlInfo, CommandStatementQuery, CommandStatementUpdate, +# DoPutUpdateResult, CommandGetCatalogs, CommandGetDbSchemas, +# CommandGetTables, CommandGetTableTypes, CommandGetPrimaryKeys, +# CommandStatementIngest (with nested TableDefinitionOptions + enums). +# --------------------------------------------------------------------------- +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 as _descriptor_pb2 + +_sym_db = _symbol_database.Default() + +_DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x19flightsql/flightsql.proto\x12\x19arrow.flight.protocol.sql' + b'\x1a google/protobuf/descriptor.proto' + # CommandGetSqlInfo + b'"&\n\x11CommandGetSqlInfo\x12\x0c\n\x04info\x18\x01 \x03(\r:\x03\xc0>\x01' + # CommandStatementQuery + b'"[\n\x15CommandStatementQuery\x12\r\n\x05query\x18\x01 \x01(\t' + b'\x12\x1b\n\x0etransaction_id\x18\x02 \x01(\x0cH\x00\x88\x01\x01:\x03\xc0>\x01' + b'B\x11\n\x0f_transaction_id' + # CommandStatementUpdate + b'"\\\n\x16CommandStatementUpdate\x12\r\n\x05query\x18\x01 \x01(\t' + b'\x12\x1b\n\x0etransaction_id\x18\x02 \x01(\x0cH\x00\x88\x01\x01:\x03\xc0>\x01' + b'B\x11\n\x0f_transaction_id' + # DoPutUpdateResult + b'".\n\x11DoPutUpdateResult\x12\x14\n\x0crecord_count\x18\x01 \x01(\x03:\x03\xc0>\x01' + # CommandGetCatalogs (no fields) + b'"\x19\n\x12CommandGetCatalogs:\x03\xc0>\x01' + # CommandGetDbSchemas + b'"\x80\x01\n\x13CommandGetDbSchemas' + b'\x12\x14\n\x07catalog\x18\x01 \x01(\tH\x00\x88\x01\x01' + b'\x12%\n\x18db_schema_filter_pattern\x18\x02 \x01(\tH\x01\x88\x01\x01' + b':\x03\xc0>\x01B\n\n\x08_catalogB\x1b\n\x19_db_schema_filter_pattern' + # CommandGetTables + b'"\xf0\x01\n\x10CommandGetTables' + b'\x12\x14\n\x07catalog\x18\x01 \x01(\tH\x00\x88\x01\x01' + 
b'\x12%\n\x18db_schema_filter_pattern\x18\x02 \x01(\tH\x01\x88\x01\x01' + b'\x12&\n\x19table_name_filter_pattern\x18\x03 \x01(\tH\x02\x88\x01\x01' + b'\x12\x13\n\x0btable_types\x18\x04 \x03(\t' + b'\x12\x16\n\x0einclude_schema\x18\x05 \x01(\x08' + b':\x03\xc0>\x01B\n\n\x08_catalogB\x1b\n\x19_db_schema_filter_pattern' + b'B\x1c\n\x1a_table_name_filter_pattern' + # CommandGetTableTypes (no fields) + b'"\x1b\n\x14CommandGetTableTypes:\x03\xc0>\x01' + # CommandGetPrimaryKeys + b'"s\n\x15CommandGetPrimaryKeys' + b'\x12\x14\n\x07catalog\x18\x01 \x01(\tH\x00\x88\x01\x01' + b'\x12\x16\n\tdb_schema\x18\x02 \x01(\tH\x01\x88\x01\x01' + b'\x12\r\n\x05table\x18\x03 \x01(\t' + b':\x03\xc0>\x01B\n\n\x08_catalogB\x0c\n\n_db_schema' + # CommandStatementIngest (with nested TableDefinitionOptions) + b'"\xb6\x07\n\x16CommandStatementIngest' + b'\x12j\n\x18table_definition_options\x18\x01 \x01(\x0b2H' + b'.arrow.flight.protocol.sql.CommandStatementIngest.TableDefinitionOptions' + b'\x12\r\n\x05table\x18\x02 \x01(\t' + b'\x12\x13\n\x06schema\x18\x03 \x01(\tH\x00\x88\x01\x01' + b'\x12\x14\n\x07catalog\x18\x04 \x01(\tH\x01\x88\x01\x01' + b'\x12\x11\n\ttemporary\x18\x05 \x01(\x08' + b'\x12\x1b\n\x0etransaction_id\x18\x06 \x01(\x0cH\x02\x88\x01\x01' + b'\x12O\n\x07options\x18\x07 \x03(\x0b2>' + b'.arrow.flight.protocol.sql.CommandStatementIngest.OptionsEntry' + b'\x1a\x99\x04\n\x16TableDefinitionOptions' + b'\x12r\n\x0cif_not_exist\x18\x01 \x01(\x0e2\\' + b'.arrow.flight.protocol.sql.CommandStatementIngest' + b'.TableDefinitionOptions.TableNotExistOption' + b'\x12m\n\tif_exists\x18\x02 \x01(\x0e2Z' + b'.arrow.flight.protocol.sql.CommandStatementIngest' + b'.TableDefinitionOptions.TableExistsOption' + b'"\x81\x01\n\x13TableNotExistOption' + b'\x12&\n"TABLE_NOT_EXIST_OPTION_UNSPECIFIED\x10\x00' + b'\x12!\n\x1dTABLE_NOT_EXIST_OPTION_CREATE\x10\x01' + b'\x12\x1f\n\x1bTABLE_NOT_EXIST_OPTION_FAIL\x10\x02' + b'"\x97\x01\n\x11TableExistsOption' + 
b'\x12#\n\x1fTABLE_EXISTS_OPTION_UNSPECIFIED\x10\x00' + b'\x12\x1c\n\x18TABLE_EXISTS_OPTION_FAIL\x10\x01' + b'\x12\x1e\n\x1aTABLE_EXISTS_OPTION_APPEND\x10\x02' + b'\x12\x1f\n\x1bTABLE_EXISTS_OPTION_REPLACE\x10\x03' + b'\x1a*\n\x0cOptionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t' + b'\x12\r\n\x05value\x18\x02 \x01(\t' + b':\x03\xc0>\x01B\t\n\x07_schemaB\n\n\x08_catalogB\x11\n\x0f_transaction_id' + # Extension: experimental (field 1000 on MessageOptions) + b':6\n\x0cexperimental\x12\x1f.google.protobuf.MessageOptions\x18\xe8\x07 \x01(\x08' + # File-level options + b'B\x00b\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(_DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(_DESCRIPTOR, 'flightsql.flightsql_pb2', _globals) + +# Expose message classes at module level. +CommandGetSqlInfo = _globals['CommandGetSqlInfo'] +CommandStatementQuery = _globals['CommandStatementQuery'] +CommandStatementUpdate = _globals['CommandStatementUpdate'] +DoPutUpdateResult = _globals['DoPutUpdateResult'] +CommandGetCatalogs = _globals['CommandGetCatalogs'] +CommandGetDbSchemas = _globals['CommandGetDbSchemas'] +CommandGetTables = _globals['CommandGetTables'] +CommandGetTableTypes = _globals['CommandGetTableTypes'] +CommandGetPrimaryKeys = _globals['CommandGetPrimaryKeys'] +CommandStatementIngest = _globals['CommandStatementIngest'] + +# --------------------------------------------------------------------------- +# Flight.proto action messages (arrow.flight.protocol package) +# +# Defines: SessionOptionValue, SetSessionOptionsRequest/Result, +# GetSessionOptionsRequest/Result, CancelFlightInfoResult, +# CancelStatus enum. 
+# --------------------------------------------------------------------------- +_FLIGHT_ACTIONS_DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x1bflight/flight_actions.proto\x12\x15arrow.flight.protocol' + # SessionOptionValue (oneof: string, bool, sfixed64, double, StringListValue) + b'"\xfc\x01\n\x12SessionOptionValue' + b'\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00' + b'\x12\x14\n\nbool_value\x18\x02 \x01(\x08H\x00' + b'\x12\x15\n\x0bint64_value\x18\x03 \x01(\x10H\x00' + b'\x12\x16\n\x0cdouble_value\x18\x04 \x01(\x01H\x00' + b'\x12V\n\x11string_list_value\x18\x05 \x01(\x0b\x32\x39' + b'.arrow.flight.protocol.SessionOptionValue.StringListValueH\x00' + b'\x1a!\n\x0fStringListValue\x12\x0e\n\x06values\x18\x01 \x03(\t' + b'B\x0e\n\x0coption_value' + # SetSessionOptionsRequest (map) + b'"\xda\x01\n\x18SetSessionOptionsRequest' + b'\x12\\\n\x0fsession_options\x18\x01 \x03(\x0b\x32\x43' + b'.arrow.flight.protocol.SetSessionOptionsRequest.SessionOptionsEntry' + b'\x1a`\n\x13SessionOptionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t' + b'\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32)' + b'.arrow.flight.protocol.SessionOptionValue:\x028\x01' + # SetSessionOptionsResult (map, ErrorValue enum) + b'"\xec\x02\n\x17SetSessionOptionsResult' + b'\x12J\n\x06\x65rrors\x18\x01 \x03(\x0b\x32:' + b'.arrow.flight.protocol.SetSessionOptionsResult.ErrorsEntry' + b'\x1aQ\n\x05\x45rror\x12H\n\x05value\x18\x01 \x01(\x0e\x32\x39' + b'.arrow.flight.protocol.SetSessionOptionsResult.ErrorValue' + b'\x1a\x63\n\x0b\x45rrorsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t' + b'\x12\x43\n\x05value\x18\x02 \x01(\x0b\x32\x34' + b'.arrow.flight.protocol.SetSessionOptionsResult.Error:\x028\x01' + b'"M\n\nErrorValue\x12\x0f\n\x0bUNSPECIFIED\x10\x00' + b'\x12\x10\n\x0cINVALID_NAME\x10\x01\x12\x11\n\rINVALID_VALUE\x10\x02' + b'\x12\t\n\x05\x45RROR\x10\x03' + # GetSessionOptionsRequest (empty) + b'"\x1a\n\x18GetSessionOptionsRequest' + # GetSessionOptionsResult (map) + 
b'"\xd8\x01\n\x17GetSessionOptionsResult' + b'\x12[\n\x0fsession_options\x18\x01 \x03(\x0b\x32\x42' + b'.arrow.flight.protocol.GetSessionOptionsResult.SessionOptionsEntry' + b'\x1a`\n\x13SessionOptionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t' + b'\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32)' + b'.arrow.flight.protocol.SessionOptionValue:\x028\x01' + # CancelFlightInfoRequest (bytes info -- serialized FlightInfo) + b'"\'\n\x17\x43\x61ncelFlightInfoRequest\x12\x0c\n\x04info\x18\x01 \x01(\x0c' + # CancelFlightInfoResult (CancelStatus enum) + b'"M\n\x16\x43\x61ncelFlightInfoResult' + b'\x12\x33\n\x06status\x18\x01 \x01(\x0e\x32#' + b'.arrow.flight.protocol.CancelStatus' + # PollInfo (bytes info, bytes flight_descriptor, optional double progress) + b'"W\n\x08PollInfo' + b'\x12\x0c\n\x04info\x18\x01 \x01(\x0c' + b'\x12\x19\n\x11\x66light_descriptor\x18\x02 \x01(\x0c' + b'\x12\x15\n\x08progress\x18\x03 \x01(\x01H\x00\x88\x01\x01' + b'B\x0b\n\t_progress' + # CancelStatus enum + b'*\x8b\x01\n\x0c\x43\x61ncelStatus' + b'\x12\x1d\n\x19\x43\x41NCEL_STATUS_UNSPECIFIED\x10\x00' + b'\x12\x1b\n\x17\x43\x41NCEL_STATUS_CANCELLED\x10\x01' + b'\x12\x1c\n\x18\x43\x41NCEL_STATUS_CANCELLING\x10\x02' + b'\x12!\n\x1d\x43\x41NCEL_STATUS_NOT_CANCELLABLE\x10\x03' + b'b\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(_FLIGHT_ACTIONS_DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + _FLIGHT_ACTIONS_DESCRIPTOR, 'flight.flight_actions_pb2', _globals) + +SessionOptionValue = _globals['SessionOptionValue'] +SetSessionOptionsRequest = _globals['SetSessionOptionsRequest'] +SetSessionOptionsResult = _globals['SetSessionOptionsResult'] +GetSessionOptionsRequest = _globals['GetSessionOptionsRequest'] +GetSessionOptionsResult = _globals['GetSessionOptionsResult'] +CancelFlightInfoRequest = _globals['CancelFlightInfoRequest'] +CancelFlightInfoResult = _globals['CancelFlightInfoResult'] +PollInfo = _globals['PollInfo'] +CancelStatus = _globals['CancelStatus'] + +# 
--------------------------------------------------------------------------- +# Thin wrapper over PollInfo to deserialize returned data +# --------------------------------------------------------------------------- +class PollResult: + """Wraps raw PollInfo proto with pyarrow-deserialized accessors.""" + + def __init__(self, proto): + self._proto = proto + + @property + def info(self): + """Deserialized FlightInfo, or None.""" + if self._proto.info: + return flight.FlightInfo.deserialize(self._proto.info) + return None + + @property + def flight_descriptor(self): + """Deserialized FlightDescriptor for next poll, or None if query is complete.""" + if self._proto.flight_descriptor: + return flight.FlightDescriptor.deserialize(self._proto.flight_descriptor) + return None + + @property + def progress(self): + """Query progress 0.0-1.0, or None if unknown.""" + if self._proto.HasField('progress'): + return self._proto.progress + return None + + @property + def info_bytes(self): + """Raw serialized FlightInfo bytes for CancelFlightInfo.""" + return self._proto.info + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def flight_descriptor(command: Any) -> flight.FlightDescriptor: + """Pack a protobuf command into a FlightDescriptor.""" + wrapper = any_pb2.Any() + wrapper.Pack(command) + return flight.FlightDescriptor.for_command(wrapper.SerializeToString()) + + +def _create_flight_client( + host: str = "localhost", + port: int = 443, + insecure: Optional[bool] = None, + disable_server_verification: Optional[bool] = None, + metadata: Optional[Dict[str, str]] = None, + **kwargs: Any, +) -> Tuple[flight.FlightClient, List[Tuple[bytes, bytes]]]: + protocol = "tls" + if insecure: + protocol = "tcp" + elif disable_server_verification: + kwargs["disable_server_verification"] = True + + url = f"grpc+{protocol}://{host}:{port}" + client = 
flight.FlightClient(url, **kwargs) + + headers: List[Tuple[bytes, bytes]] = [] + for k, v in (metadata or {}).items(): + headers.append((k.encode("utf-8"), v.encode("utf-8"))) + + return client, headers + + +# --------------------------------------------------------------------------- +# Client +# --------------------------------------------------------------------------- + +class FlightSQLClient: + """ + Thin Flight SQL wrapper around pyarrow.flight.FlightClient. + + Implements only the subset used by the ClickHouse integration tests: + execute, execute_update, do_get, and Flight SQL metadata commands, + plus access to the underlying client. + """ + + def __init__(self, *args, features: Optional[Dict[str, str]] = None, **kwargs): + self._host = kwargs.get('host', args[0] if args else 'localhost') + self._port = kwargs.get('port', args[1] if len(args) > 1 else 443) + self._insecure = kwargs.get('insecure', False) + self._tls_root_certs = kwargs.get('tls_root_certs') + self._override_hostname = kwargs.get('override_hostname') + + client, headers = _create_flight_client(*args, **kwargs) + self.client = client + self.headers = headers + self.features = features or {} + + def _flight_call_options(self): + headers = list(OrderedDict(self.headers).items()) + return flight.FlightCallOptions(headers=headers) + + def execute(self, query: str) -> flight.FlightInfo: + """Execute a query and return FlightInfo for result retrieval.""" + cmd = CommandStatementQuery(query=query) + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def execute_update(self, query: str) -> int: + """Execute a DDL/DML statement and return the affected row count.""" + cmd = CommandStatementUpdate(query=query) + desc = flight_descriptor(cmd) + options = self._flight_call_options() + writer, reader = self.client.do_put( + desc, pa.schema([]), options + ) + result = reader.read() + writer.close() + + if result is None: + return 0 + 
update_result = DoPutUpdateResult() + update_result.ParseFromString(result.to_pybytes()) + return update_result.record_count + + def do_get(self, ticket) -> flight.FlightStreamReader: + """Retrieve Arrow data for a given ticket.""" + options = self._flight_call_options() + return self.client.do_get(ticket, options) + + # ----------------------------------------------------------------------- + # Flight SQL metadata commands + # ----------------------------------------------------------------------- + + def get_sql_info(self, info_ids: Optional[List[int]] = None) -> flight.FlightInfo: + """Retrieve server metadata via CommandGetSqlInfo.""" + cmd = CommandGetSqlInfo() + if info_ids: + for info_id in info_ids: + cmd.info.append(info_id) + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_catalogs(self) -> flight.FlightInfo: + """Retrieve catalog list via CommandGetCatalogs.""" + cmd = CommandGetCatalogs() + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_db_schemas(self, db_schema_filter_pattern: Optional[str] = None) -> flight.FlightInfo: + """Retrieve database schemas via CommandGetDbSchemas.""" + cmd = CommandGetDbSchemas() + if db_schema_filter_pattern is not None: + cmd.db_schema_filter_pattern = db_schema_filter_pattern + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_tables( + self, + db_schema_filter_pattern: Optional[str] = None, + table_name_filter_pattern: Optional[str] = None, + table_types: Optional[List[str]] = None, + include_schema: bool = False, + ) -> flight.FlightInfo: + """Retrieve table list via CommandGetTables.""" + cmd = CommandGetTables() + if db_schema_filter_pattern is not None: + cmd.db_schema_filter_pattern = db_schema_filter_pattern + if table_name_filter_pattern is not None: + cmd.table_name_filter_pattern = 
table_name_filter_pattern + if table_types: + for t in table_types: + cmd.table_types.append(t) + cmd.include_schema = include_schema + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_table_types(self) -> flight.FlightInfo: + """Retrieve table engine types via CommandGetTableTypes.""" + cmd = CommandGetTableTypes() + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_primary_keys(self, table: str, db_schema: Optional[str] = None) -> flight.FlightInfo: + """Retrieve primary keys for a table via CommandGetPrimaryKeys.""" + cmd = CommandGetPrimaryKeys() + cmd.table = table + if db_schema is not None: + cmd.db_schema = db_schema + options = self._flight_call_options() + return self.client.get_flight_info(flight_descriptor(cmd), options) + + def get_schema(self, query: str) -> flight.SchemaResult: + """Retrieve query result schema without executing via GetSchema.""" + cmd = CommandStatementQuery(query=query) + options = self._flight_call_options() + return self.client.get_schema(flight_descriptor(cmd), options) + + def set_session_options(self, options: Dict[str, Any]) -> SetSessionOptionsResult: + """Set session options via the SetSessionOptions action. + + Use None as a value to reset a setting to its default (sends a valueless + SessionOptionValue, which the server interprets as SET setting = DEFAULT). 
+ """ + req = SetSessionOptionsRequest() + for key, value in options.items(): + opt_val = SessionOptionValue() + if value is None: + pass # leave opt_val empty — server treats this as "reset to default" + elif isinstance(value, str): + opt_val.string_value = value + elif isinstance(value, bool): + opt_val.bool_value = value + elif isinstance(value, int): + opt_val.int64_value = value + elif isinstance(value, float): + opt_val.double_value = value + elif isinstance(value, list): + opt_val.string_list_value.values.extend(value) + req.session_options[key].CopyFrom(opt_val) + + action = flight.Action("SetSessionOptions", req.SerializeToString()) + results = list(self.client.do_action(action, self._flight_call_options())) + result = SetSessionOptionsResult() + result.ParseFromString(results[0].body.to_pybytes()) + return result + + def get_session_options(self) -> GetSessionOptionsResult: + """Get current session options via the GetSessionOptions action.""" + req = GetSessionOptionsRequest() + action = flight.Action("GetSessionOptions", req.SerializeToString()) + results = list(self.client.do_action(action, self._flight_call_options())) + result = GetSessionOptionsResult() + result.ParseFromString(results[0].body.to_pybytes()) + return result + + def cancel_flight_info(self, flight_info_bytes: bytes) -> CancelFlightInfoResult: + """Cancel a query via the CancelFlightInfo action.""" + req = CancelFlightInfoRequest() + req.info = flight_info_bytes + action = flight.Action("CancelFlightInfo", req.SerializeToString()) + results = list(self.client.do_action(action, self._flight_call_options())) + result = CancelFlightInfoResult() + result.ParseFromString(results[0].body.to_pybytes()) + return result + + def poll_flight_info(self, descriptor) -> PollResult: + """Call PollFlightInfo RPC via grpc. + + pyarrow.flight.FlightClient does not expose PollFlightInfo despite + it being part of the Flight protocol and implemented in the + underlying C++ library. 
We work around this by making the gRPC + call directly, reusing the same auth headers. + """ + import grpc as _grpc + + target = f'{self._host}:{self._port}' + if self._insecure: + channel = _grpc.insecure_channel(target) + else: + credentials = _grpc.ssl_channel_credentials(root_certificates=self._tls_root_certs) + options = [] + if self._override_hostname: + options.append(('grpc.ssl_target_name_override', self._override_hostname)) + channel = _grpc.secure_channel(target, credentials, options=options) + + try: + call = channel.unary_unary( + '/arrow.flight.protocol.FlightService/PollFlightInfo', + request_serializer=lambda x: x, + response_deserializer=lambda x: x, + ) + metadata = list(OrderedDict(self.headers).items()) + raw_response = call(descriptor.serialize(), metadata=metadata, timeout=30) + finally: + channel.close() + + result = PollInfo() + result.ParseFromString(raw_response) + return PollResult(result) diff --git a/tests/integration/test_arrowflight_interface/test.py b/tests/integration/test_arrowflight_interface/test.py index 634ade912d13..327e04d7fb6f 100644 --- a/tests/integration/test_arrowflight_interface/test.py +++ b/tests/integration/test_arrowflight_interface/test.py @@ -192,7 +192,7 @@ def test_doput_cmd_insert_invalid_format(): writer.close() assert False, "Expected to fail because of a wrong format but succeeded" except flight.FlightServerError as e: - assert "Invalid format, only 'Arrow' format is supported" in str(e) + assert "Invalid format (JSON), only 'Arrow' format is supported" in str(e) # INSERT queries without the FORMAT clause are considered invalid. 
@@ -621,13 +621,13 @@ def test_invalid_user(): client = flight.FlightClient( f"grpc+tls://{node.ip_address}:8888", disable_server_verification=True ) - token = client.authenticate_basic_token(b"invalid", b"password") - options = flight.FlightCallOptions(headers=[token]) - ticket = flight.Ticket(b"SELECT * FROM mytable") try: + token = client.authenticate_basic_token(b"invalid", b"password") + options = flight.FlightCallOptions(headers=[token]) + ticket = flight.Ticket(b"SELECT * FROM mytable") client.do_get(ticket, options) assert False, "Expected authentication failure (login and password are not correct) but succeeded" - except flight.FlightServerError as e: + except flight.FlightUnauthenticatedError as e: assert ( "Authentication failed: password is incorrect, or there is no user with such name" in str(e) diff --git a/tests/integration/test_arrowflight_interface/test_sql_server.py b/tests/integration/test_arrowflight_interface/test_sql_server.py new file mode 100644 index 000000000000..aceb9d70ebf5 --- /dev/null +++ b/tests/integration/test_arrowflight_interface/test_sql_server.py @@ -0,0 +1,871 @@ +# coding: utf-8 + +import os +import pytest +import pyarrow as pa +import pyarrow.flight as flight +import random +import string +from .flight_sql_client import ( + FlightSQLClient, + flight_descriptor, + CommandStatementUpdate, + DoPutUpdateResult, + CancelStatus, + SetSessionOptionsResult, + CommandStatementQuery, + CommandStatementIngest, +) + + +from helpers.cluster import ClickHouseCluster, get_docker_compose_path +from helpers.test_tools import TSV + + +SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) +DOCKER_COMPOSE_PATH = get_docker_compose_path() + +cluster = ClickHouseCluster(__file__) +node = cluster.add_instance( + "node", + main_configs=[ + "configs/flight_port.xml", + ], +) + +session_id = ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + +def get_client(): + return FlightSQLClient( + host=node.ip_address, + port=8888, + 
insecure=True, + disable_server_verification=True, + metadata={'x-clickhouse-session-id': session_id}, + features={'metadata-reflection': 'true'}, # makes the client emit metadata retrieval commands upon connection + ) + + +@pytest.fixture(scope="module", autouse=True) +def start_cluster(): + try: + cluster.start() + node.wait_until_port_is_ready(8888, timeout=10) + yield cluster + finally: + cluster.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_after_test(): + try: + yield + finally: + node.query("DROP TABLE IF EXISTS mytable, map_test, large_test, bulk_test SYNC") + + +def test_select(): + client = get_client() + flight_info = client.execute("SELECT 1, 'hello', 3.14") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + tsv_output = table.to_pandas().to_csv(sep='\t', index=False, header=False) + + assert tsv_output == "1\thello\t3.14\n" + +def test_create_table_and_insert(): + client = get_client() + + # Create table + client.execute_update("CREATE TABLE mytable (id UInt32, name String, value Float64) ENGINE = Memory") + + # Insert data + client.execute_update("INSERT INTO mytable VALUES (1, 'test', 42.5), (2, 'hello', 3.14)") + + # Query and verify + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + pandas_df = table.to_pandas() + tsv_output = pandas_df.to_csv(sep='\t', index=False, header=False) + + expected = "1\ttest\t42.5\n2\thello\t3.14\n" + assert tsv_output == expected + + +def test_map_data_type(): + client = get_client() + + # Test Map data type handling + client.execute_update("CREATE TABLE map_test (id UInt32, data Map(String, UInt64)) ENGINE = Memory") + client.execute_update("INSERT INTO map_test VALUES (1, {'key1': 100, 'key2': 200})") + + flight_info = client.execute("SELECT * FROM map_test") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + # Verify we can 
read the map data without errors + assert table.num_rows == 1 + assert table.num_columns == 2 + + # Check that the map column has the correct Arrow type + map_column = table.column(1) + assert isinstance(map_column.type, pa.MapType) + + +def test_error_handling(): + client = get_client() + + # Test invalid SQL + with pytest.raises(flight.FlightServerError): + client.execute("INVALID SQL SYNTAX") + + # Test querying non-existent table + with pytest.raises(flight.FlightServerError): + client.execute("SELECT * FROM non_existent_table") + + +def test_large_result_set(): + client = get_client() + + # Create table with many rows to test streaming + client.execute_update("CREATE TABLE large_test (id UInt32, value String) ENGINE = Memory") + client.execute_update("INSERT INTO large_test SELECT number, toString(number) FROM numbers(10000)") + + flight_info = client.execute("SELECT COUNT(*) FROM large_test") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + count_value = table.column(0)[0].as_py() + assert count_value == 10000 + + +def test_streaming_insert(): + """ + Test bulk data insertion via Arrow Flight SQL. + + Note: This test uses a workaround due to Arrow Flight SQL version limitations. + Arrow Flight SQL v11 lacks bulk ingestion functionality (CommandStatementIngest), + which was introduced in v12. ClickHouse supports a non-standard approach using + CommandStatementUpdate, but this is not supported by the flightsql-dbapi module. + + This implementation uses a mix of the underlying Flight API with the Flight SQL + protobuf definitions. When upgrading to Arrow Flight SQL v12+, this test should + be replaced with the standard CommandStatementIngest approach. 
+ """ + client = get_client() + + client.execute_update("CREATE TABLE bulk_test (id UInt32, str String) ENGINE = Memory") + + cmd = CommandStatementUpdate(query="INSERT INTO bulk_test FORMAT Arrow") + descriptor = flight_descriptor(cmd) + schema = pa.schema([ + ("id", pa.uint32()), + ("str", pa.string()), + ]) + + writer, reader = client.client.do_put(descriptor, schema, client._flight_call_options()) + + for n in range(1000): + batch = pa.record_batch([ + pa.array([n*1, n*2, n*3, n*4, n*5, n*6, n*7], type=pa.uint32()), + pa.array([str(n*1), str(n*2), str(n*3), str(n*4), str(n*5), str(n*6), str(n*7)], type=pa.string()), + ], schema=schema) + writer.write_batch(batch) + + writer.done_writing() + + result = reader.read() + + assert result is not None + update_result = DoPutUpdateResult() + update_result.ParseFromString(result.to_pybytes()) + assert update_result.record_count == 7000 + + +# +# Flight SQL Metadata Commands +# + +def test_get_sql_info(): + """CommandGetSqlInfo returns server metadata.""" + client = get_client() + flight_info = client.get_sql_info() + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + # Should have info_name (uint32) and value (dense_union) columns + assert table.num_columns == 2 + assert table.column_names == ["info_name", "value"] + assert table.num_rows > 0 + + # Convert to dict for easier assertions + info = {} + for i in range(table.num_rows): + info[table.column("info_name")[i].as_py()] = table.column("value")[i].as_py() + + # FLIGHT_SQL_SERVER_NAME = 0 + assert info[0] == "ClickHouse" + # FLIGHT_SQL_SERVER_READ_ONLY = 3 + assert info[3] == False + # FLIGHT_SQL_SERVER_SQL = 4 + assert info[4] == True + # FLIGHT_SQL_SERVER_SUBSTRAIT = 5 + assert info[5] == False + # FLIGHT_SQL_SERVER_CANCEL = 9 + assert info[9] == True + + +def test_get_sql_info_filtered(): + """CommandGetSqlInfo with specific info IDs returns only requested items.""" + client = get_client() + # Request only 
FLIGHT_SQL_SERVER_NAME (0) and FLIGHT_SQL_SERVER_VERSION (1) + flight_info = client.get_sql_info(info_ids=[0, 1]) + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 2 + + +def test_get_catalogs(): + """CommandGetCatalogs returns empty result (ClickHouse has no catalogs).""" + client = get_client() + flight_info = client.get_catalogs() + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 0 + assert "catalog_name" in table.column_names + + +def test_get_db_schemas(): + """CommandGetDbSchemas returns database list.""" + client = get_client() + flight_info = client.get_db_schemas() + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + schemas = [table.column("db_schema_name")[i].as_py() for i in range(table.num_rows)] + assert "default" in schemas + assert "system" in schemas + + +def test_get_db_schemas_with_filter(): + """CommandGetDbSchemas with filter pattern.""" + client = get_client() + flight_info = client.get_db_schemas(db_schema_filter_pattern="def%") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + schemas = [table.column("db_schema_name")[i].as_py() for i in range(table.num_rows)] + assert "default" in schemas + assert "system" not in schemas + + +def test_get_tables(): + """CommandGetTables returns table list.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32) ENGINE = Memory") + + flight_info = client.get_tables( + db_schema_filter_pattern="default", + table_name_filter_pattern="mytable" + ) + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + assert table.column("table_name")[0].as_py() == "mytable" + + +def test_get_tables_with_schema(): + """CommandGetTables with include_schema=True returns Arrow schema bytes.""" + client = get_client() + 
client.execute_update( + "CREATE TABLE mytable (id UInt32, name String, value Float64) ENGINE = Memory" + ) + + flight_info = client.get_tables( + db_schema_filter_pattern="default", + table_name_filter_pattern="mytable", + include_schema=True + ) + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + assert "table_schema" in table.column_names + # table_schema column should contain serialized Arrow schema bytes + schema_bytes = table.column("table_schema")[0].as_py() + assert len(schema_bytes) > 0 + + +def test_get_table_types(): + """CommandGetTableTypes returns engine types.""" + client = get_client() + flight_info = client.get_table_types() + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + types = [table.column("table_type")[i].as_py() for i in range(table.num_rows)] + assert "REMOTE TABLE" in types + assert "VIEW" in types + assert "UNKNOWN TABLE TYPE" not in types, \ + "Some engine(s) in system.table_engines are not mapped in engine_to_type (commandSelector.cpp)" + + +def test_get_primary_keys(): + """CommandGetPrimaryKeys returns primary key columns.""" + client = get_client() + client.execute_update( + "CREATE TABLE mytable (id UInt32, name String, value Float64) ENGINE = MergeTree ORDER BY (id, name)" + ) + + flight_info = client.get_primary_keys(table="mytable", db_schema="default") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 2 + columns = [table.column("column_name")[i].as_py() for i in range(table.num_rows)] + assert columns == ["id", "name"] + # key_seq should be 1-based sequential + seqs = [table.column("key_seq")[i].as_py() for i in range(table.num_rows)] + assert seqs == [1, 2] + + +# +# DoAction Tests +# + +def test_set_session_options(): + """SetSessionOptions sets ClickHouse settings.""" + client = get_client() + result = client.set_session_options({"max_threads": "4"}) + 
assert len(result.errors) == 0 + + +def test_set_session_options_invalid_setting(): + """SetSessionOptions with unknown setting returns INVALID_NAME error.""" + client = get_client() + result = client.set_session_options({"nonexistent_setting_xyz": "value"}) + assert "nonexistent_setting_xyz" in result.errors + assert result.errors["nonexistent_setting_xyz"].value == SetSessionOptionsResult.INVALID_NAME + + +def test_get_session_options(): + """GetSessionOptions returns current settings.""" + client = get_client() + result = client.get_session_options() + assert "max_threads" in result.session_options + assert result.session_options["max_threads"].string_value != "" + + +def _query_setting(client, name): + """Read the current value of a setting via SQL query.""" + flight_info = client.execute(f"SELECT value FROM system.settings WHERE name = '{name}'") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + return table.column(0)[0].as_py() + + +def test_set_session_options_persistence(): + """SetSessionOptions changes persist and are visible in subsequent queries.""" + client = get_client() + + # Reset max_threads to default first (previous tests may have modified it) + result = client.set_session_options({"max_threads": None}) + assert len(result.errors) == 0 + + # Read the default value + default_value = _query_setting(client, "max_threads") + + # Pick a value that differs from the default + new_value = "7" if default_value != "7" else "5" + + # Set the setting via SetSessionOptions + result = client.set_session_options({"max_threads": new_value}) + assert len(result.errors) == 0 + + # Verify the setting persists via SQL query + assert _query_setting(client, "max_threads") == new_value + + # Verify via GetSessionOptions as well + options = client.get_session_options() + assert options.session_options["max_threads"].string_value == new_value + + # Reset to default + result = client.set_session_options({"max_threads": None}) + assert 
len(result.errors) == 0 + + # Verify the setting was restored to the original default + assert _query_setting(client, "max_threads") == default_value + + +def test_cancel_flight_info(): + client = get_client() + + descriptor = flight.FlightDescriptor.for_command( + b"SELECT sleepEachRow(0.5) FROM numbers(100)" + ) + poll_result = client.poll_flight_info(descriptor) + assert poll_result.info is not None + + result = client.cancel_flight_info(poll_result.info_bytes) + assert result.status == CancelStatus.Value('CANCEL_STATUS_CANCELLED') + + +def test_unsupported_action(): + """Unsupported action type returns error.""" + client = get_client() + action = flight.Action("SomeUnsupportedAction", b"") + with pytest.raises(pa.lib.ArrowNotImplementedError, match="not supported"): + list(client.client.do_action(action, client._flight_call_options())) + + +# +# PollFlightInfo Tests +# + +def test_poll_flight_info_basic(): + """PollFlightInfo streams results incrementally.""" + client = get_client() + + client.execute_update("CREATE TABLE mytable (id UInt32) ENGINE = Memory") + client.execute_update("INSERT INTO mytable SELECT number FROM numbers(100)") + + descriptor = flight.FlightDescriptor.for_command(b"SELECT * FROM mytable") + + poll_result = client.poll_flight_info(descriptor) + assert poll_result.info is not None + + # Collect all FlightInfo bytes by polling until no next descriptor + all_infos = [poll_result.info] + while poll_result.flight_descriptor is not None: + poll_result = client.poll_flight_info(poll_result.flight_descriptor) + all_infos.append(poll_result.info) + + # Read all data via tickets + total_rows = 0 + for endpoint in all_infos[-1].endpoints: + reader = client.do_get(endpoint.ticket) + table = reader.read_all() + total_rows += table.num_rows + + assert total_rows == 100 + + +def test_poll_flight_info_with_path_descriptor(): + """PollFlightInfo works with PATH descriptor (table name).""" + client = get_client() + + client.execute_update("CREATE TABLE 
mytable (id UInt32, name String) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, 'a'), (2, 'b')") + + descriptor = flight.FlightDescriptor.for_path("mytable") + + poll_result = client.poll_flight_info(descriptor) + assert poll_result.info is not None + assert poll_result.info.total_records >= 0 + + # Cancel the running query so cleanup can drop the table + client.cancel_flight_info(poll_result.info_bytes) + + +# +# GetSchema Tests +# + +def test_get_schema(): + """GetSchema returns schema without executing the query.""" + client = get_client() + + client.execute_update( + "CREATE TABLE mytable (id UInt32, name String, value Float64) ENGINE = Memory" + ) + + # GetSchema via Flight SQL CommandStatementQuery + schema_result = client.get_schema("SELECT * FROM mytable") + schema = schema_result.schema + + assert len(schema) == 3 + assert schema.field("id").type == pa.uint32() + assert schema.field("name").type == pa.string() + assert schema.field("value").type == pa.float64() + + +def test_get_schema_path_descriptor(): + """GetSchema works with PATH descriptor.""" + client = get_client() + + client.execute_update("CREATE TABLE mytable (id Int64, name String) ENGINE = Memory") + + descriptor = flight.FlightDescriptor.for_path("mytable") + options = client._flight_call_options() + + schema_result = client.client.get_schema(descriptor, options) + schema = schema_result.schema + + assert schema.field("id").type == pa.int64() + assert schema.field("name").type == pa.string() + + +# +# Data Type Coverage +# + +def test_array_data_type(): + """Array type round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, arr Array(UInt32)) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, [10, 20, 30])") + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + assert 
isinstance(table.column("arr").type, pa.ListType) + assert table.column("arr")[0].as_py() == [10, 20, 30] + + +def test_tuple_data_type(): + """Tuple type round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, t Tuple(String, UInt32)) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, ('hello', 42))") + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + # Tuple maps to Arrow struct + assert isinstance(table.column("t").type, pa.StructType) + + +def test_nullable_data_type(): + """Nullable type round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, val Nullable(String)) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, 'hello'), (2, NULL)") + + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 2 + assert table.column("val")[0].as_py() == "hello" + assert table.column("val")[1].as_py() is None + + +def test_datetime_data_types(): + """DateTime and DateTime64 round-trip.""" + client = get_client() + client.execute_update( + "CREATE TABLE mytable (id UInt32, dt DateTime, dt64 DateTime64(3)) ENGINE = Memory" + ) + client.execute_update( + "INSERT INTO mytable VALUES (1, '2024-01-15 10:30:00', '2024-01-15 10:30:00.123')" + ) + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + # DateTime maps to uint32 (unix timestamp) + assert table.column("dt").type == pa.uint32() + assert table.column("dt")[0].as_py() == 1705314600 + # DateTime64 maps to Arrow timestamp + assert pa.types.is_timestamp(table.column("dt64").type) + +def test_decimal_data_type(): + """Decimal type 
round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, val Decimal(18, 4)) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, 123.4567)") + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + assert pa.types.is_decimal(table.column("val").type) + + +def test_uuid_data_type(): + """UUID type round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, uid UUID) ENGINE = Memory") + client.execute_update( + "INSERT INTO mytable VALUES (1, '550e8400-e29b-41d4-a716-446655440000')" + ) + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 1 + + +def test_lowcardinality_data_type(): + """LowCardinality type round-trip.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, val LowCardinality(String)) ENGINE = Memory") + client.execute_update("INSERT INTO mytable VALUES (1, 'aaa'), (2, 'bbb'), (3, 'aaa')") + + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 3 + vals = [table.column("val")[i].as_py() for i in range(3)] + assert vals == ["aaa", "bbb", "aaa"] + + +def test_enum_data_type(): + """Enum type round-trip.""" + client = get_client() + client.execute_update( + "CREATE TABLE mytable (id UInt32, status Enum8('ok' = 1, 'error' = 2)) ENGINE = Memory" + ) + client.execute_update("INSERT INTO mytable VALUES (1, 'ok'), (2, 'error')") + + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 2 + + +# +# Session Management +# + +def 
test_session_state_persistence(): + """Session ID preserves state across requests (e.g., temp tables, settings).""" + client = get_client() # already uses x-clickhouse-session-id + + client.execute_update("SET max_threads = 2") + + flight_info = client.execute("SELECT value FROM system.settings WHERE name = 'max_threads'") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.column(0)[0].as_py() == "2" + + +def test_different_sessions_are_independent(): + """Different session IDs have independent state.""" + import random, string + session_id_1 = ''.join(random.choices(string.ascii_letters, k=16)) + session_id_2 = ''.join(random.choices(string.ascii_letters, k=16)) + + client1 = FlightSQLClient( + host=node.ip_address, port=8888, insecure=True, + disable_server_verification=True, + metadata={'x-clickhouse-session-id': session_id_1}, + ) + client2 = FlightSQLClient( + host=node.ip_address, port=8888, insecure=True, + disable_server_verification=True, + metadata={'x-clickhouse-session-id': session_id_2}, + ) + + client1.execute_update("SET max_threads = 3") + + # client2 should still see the default + flight_info = client2.execute("SELECT value FROM system.settings WHERE name = 'max_threads'") + reader = client2.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + # Should NOT be "3" since it's a different session + assert table.column(0)[0].as_py() != "3" + + +# +# Bearer Token Authentication +# + +def test_bearer_token_reuse(): + """After Basic auth, the returned Bearer token can authenticate subsequent requests.""" + client = flight.FlightClient(f"grpc://{node.ip_address}:8888") + + # First request with Basic auth returns a Bearer token + token_pair = client.authenticate_basic_token("default", "") + options = flight.FlightCallOptions(headers=[token_pair]) + + # Use the Bearer token for a query + ticket = flight.Ticket(b"SELECT 1") + reader = client.do_get(ticket, options) + table = 
reader.read_all() + assert table.column(0)[0].as_py() == 1 + + +# +# Edge Cases +# + +def test_empty_result_set(): + """Query returning zero rows produces valid empty table.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, name String) ENGINE = Memory") + + flight_info = client.execute("SELECT * FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 0 + assert table.num_columns == 2 + assert table.schema.field("id").type == pa.uint32() + assert table.schema.field("name").type == pa.string() + + +def test_empty_query_in_command_statement(): + """CommandStatementQuery with empty query returns error.""" + client = get_client() + # Construct a CommandStatementQuery with empty query string + cmd = CommandStatementQuery(query="") + desc = flight_descriptor(cmd) + options = client._flight_call_options() + + with pytest.raises(pa.lib.ArrowInvalid, match="query must not be empty"): + client.client.get_flight_info(desc, options) + + +def test_multiple_statements_via_execute_update(): + """Multiple DDL/DML via execute_update in sequence.""" + client = get_client() + + client.execute_update("CREATE TABLE mytable (id UInt32, val String) ENGINE = Memory") + + for i in range(10): + client.execute_update(f"INSERT INTO mytable VALUES ({i}, 'row_{i}')") + + flight_info = client.execute("SELECT count() FROM mytable") + reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.column(0)[0].as_py() == 10 + + +def test_special_characters_in_data(): + """Data with special characters (unicode, quotes, newlines) round-trips correctly.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, val String) ENGINE = Memory") + client.execute_update( + r"INSERT INTO mytable VALUES (1, 'hello\nworld'), (2, 'it''s \"quoted\"'), (3, '日本語テスト')" + ) + + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + 
reader = client.do_get(flight_info.endpoints[0].ticket) + table = reader.read_all() + + assert table.num_rows == 3 + assert table.column("val")[2].as_py() == '日本語テスト' + + +# +# CommandStatementIngest +# + +def test_statement_ingest(): + """CommandStatementIngest inserts data into existing table.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32, name String) ENGINE = Memory") + + cmd = CommandStatementIngest() + cmd.table = "mytable" + cmd.table_definition_options.if_not_exist = ( + CommandStatementIngest.TableDefinitionOptions.TABLE_NOT_EXIST_OPTION_FAIL + ) + cmd.table_definition_options.if_exists = ( + CommandStatementIngest.TableDefinitionOptions.TABLE_EXISTS_OPTION_APPEND + ) + + descriptor = flight_descriptor(cmd) + schema = pa.schema([("id", pa.uint32()), ("name", pa.string())]) + + writer, reader = client.client.do_put(descriptor, schema, client._flight_call_options()) + batch = pa.record_batch( + [pa.array([1, 2, 3], type=pa.uint32()), pa.array(["a", "b", "c"], type=pa.string())], + schema=schema, + ) + writer.write_batch(batch) + writer.done_writing() + result = reader.read() + writer.close() + + update_result = DoPutUpdateResult() + update_result.ParseFromString(result.to_pybytes()) + assert update_result.record_count == 3 + + # Verify data + flight_info = client.execute("SELECT * FROM mytable ORDER BY id") + r = client.do_get(flight_info.endpoints[0].ticket) + t = r.read_all() + assert t.num_rows == 3 + + +def test_statement_ingest_with_schema(): + """CommandStatementIngest with database schema prefix.""" + client = get_client() + client.execute_update("CREATE TABLE default.mytable (id UInt32) ENGINE = Memory") + + cmd = CommandStatementIngest() + cmd.table = "mytable" + cmd.schema = "default" + cmd.table_definition_options.if_not_exist = ( + CommandStatementIngest.TableDefinitionOptions.TABLE_NOT_EXIST_OPTION_FAIL + ) + cmd.table_definition_options.if_exists = ( + 
CommandStatementIngest.TableDefinitionOptions.TABLE_EXISTS_OPTION_APPEND + ) + + descriptor = flight_descriptor(cmd) + schema = pa.schema([("id", pa.uint32())]) + writer, reader = client.client.do_put(descriptor, schema, client._flight_call_options()) + batch = pa.record_batch([pa.array([1], type=pa.uint32())], schema=schema) + writer.write_batch(batch) + writer.done_writing() + reader.read() + writer.close() + + +def test_statement_ingest_catalog_not_supported(): + """CommandStatementIngest with catalog returns NotImplemented.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32) ENGINE = Memory") + + cmd = CommandStatementIngest() + cmd.table = "mytable" + cmd.catalog = "some_catalog" + + descriptor = flight_descriptor(cmd) + schema = pa.schema([("id", pa.uint32())]) + + with pytest.raises(pa.lib.ArrowNotImplementedError, match="Catalogs are not supported"): + writer, reader = client.client.do_put(descriptor, schema, client._flight_call_options()) + batch = pa.record_batch([pa.array([1], type=pa.uint32())], schema=schema) + writer.write_batch(batch) + writer.close() + + +def test_statement_ingest_temporary_not_supported(): + """CommandStatementIngest with temporary=True returns NotImplemented.""" + client = get_client() + client.execute_update("CREATE TABLE mytable (id UInt32) ENGINE = Memory") + + cmd = CommandStatementIngest() + cmd.table = "mytable" + cmd.temporary = True + + descriptor = flight_descriptor(cmd) + schema = pa.schema([("id", pa.uint32())]) + + with pytest.raises(pa.lib.ArrowNotImplementedError, match="Implicit temporary tables are not supported"): + writer, reader = client.client.do_put(descriptor, schema, client._flight_call_options()) + batch = pa.record_batch([pa.array([1], type=pa.uint32())], schema=schema) + writer.write_batch(batch) + writer.close() diff --git a/tests/queries/0_stateless/04070_arrow_complex_types.reference b/tests/queries/0_stateless/04070_arrow_complex_types.reference new file mode 100644 
index 000000000000..dc0f706348eb --- /dev/null +++ b/tests/queries/0_stateless/04070_arrow_complex_types.reference @@ -0,0 +1,43 @@ +=== Arrays === +[] [] [] +[42] ['single'] [NULL] +[1,2,3] ['hello','world'] [1.5,NULL,3.5] +[0,0,0,0,0] ['a','b','c','d','e'] [NULL,NULL,NULL,NULL,NULL] +=== Nested Arrays === +1 [[1,2],[3]] [['a',NULL],['b']] +2 [[],[4,5,6]] [[NULL,NULL],[]] +3 [[]] [[]] +=== Maps === +{} {} +{'single':42} {'only':'one'} +{'key1':1,'key2':2} {'a':'x','b':'y'} +=== Tuples === +(0,'',0) ((0,0),'') +(1,'hello',3.14) ((10,20),'nested') +(42,'world',-1.5) ((100,200),'deep') +=== Named Tuples === +(1,'abc',1.1) +(2,'def',2.2) +(3,'ghi',3.3) +=== Variants === +Variant export OK, size: positive +=== Nested Combinations === +[] {} ([],{}) +[(42,'only')] {'single':[100]} ([0],{'':'empty'}) +[(1,'a'),(2,'b')] {'x':[1,2,3],'y':[4,5]} ([10,20],{'k':'v'}) +=== Multi-block === +0 [0,0,0] {'k0':0} (0,'0') +1 [1,2,3] {'k1':1} (1,'1') +2 [2,4,6] {'k2':2} (2,'2') +3 [3,6,9] {'k3':3} (3,'3') +4 [4,8,12] {'k4':4} (4,'4') +5 [5,10,15] {'k5':5} (5,'5') +6 [6,12,18] {'k6':6} (6,'6') +7 [7,14,21] {'k7':7} (7,'7') +8 [8,16,24] {'k8':8} (8,'8') +9 [9,18,27] {'k9':9} (9,'9') +=== Map with Nullable values === +1 {'a':1,'b':NULL} +2 {} +3 {'c':NULL,'d':NULL} +4 {'e':42} diff --git a/tests/queries/0_stateless/04070_arrow_complex_types.sh b/tests/queries/0_stateless/04070_arrow_complex_types.sh new file mode 100755 index 000000000000..693d33a5017b --- /dev/null +++ b/tests/queries/0_stateless/04070_arrow_complex_types.sh @@ -0,0 +1,202 @@ +#!/usr/bin/env bash +# Tags: no-fasttest + +# Test Arrow format serialization/deserialization for complex types: +# Arrays (with edge cases), Maps, Tuples, Variants, +# and nested combinations thereof. + +set -e + +CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CUR_DIR"/../shell_config.sh + +# ---------------------------------------------- +# 1. 
Arrays: basic roundtrip with edge cases +# ---------------------------------------------- + +echo "=== Arrays ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_arrays_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_arrays_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_arrays_src (a Array(UInt32), b Array(String), c Array(Nullable(Float64))) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_arrays_dst AS arrow_test_arrays_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_arrays_src VALUES ([1, 2, 3], ['hello', 'world'], [1.5, NULL, 3.5]), ([], [], []), ([42], ['single'], [NULL]), ([0, 0, 0, 0, 0], ['a', 'b', 'c', 'd', 'e'], [NULL, NULL, NULL, NULL, NULL])" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_arrays_src ORDER BY length(a) FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_arrays_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_arrays_dst ORDER BY length(a)" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_arrays_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_arrays_dst" + +# ---------------------------------------------- +# 2. 
Nested arrays +# ---------------------------------------------- + +echo "=== Nested Arrays ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_nested_arr_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_nested_arr_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_nested_arr_src (id UInt32, a Array(Array(UInt32)), b Array(Array(Nullable(String)))) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_nested_arr_dst AS arrow_test_nested_arr_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_nested_arr_src VALUES (1, [[1,2],[3]], [['a',NULL],['b']]), (2, [[], [4,5,6]], [[NULL, NULL],[]]), (3, [[]], [[]])" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_nested_arr_src ORDER BY id FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_nested_arr_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_nested_arr_dst ORDER BY id" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_nested_arr_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_nested_arr_dst" + +# ---------------------------------------------- +# 3. 
Maps +# ---------------------------------------------- + +echo "=== Maps ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_maps_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_maps_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_maps_src (a Map(String, UInt64), b Map(String, String)) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_maps_dst AS arrow_test_maps_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_maps_src VALUES ({'key1': 1, 'key2': 2}, {'a': 'x', 'b': 'y'}), ({}, {}), ({'single': 42}, {'only': 'one'})" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_maps_src ORDER BY length(a) FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_maps_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_maps_dst ORDER BY length(a)" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_maps_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_maps_dst" + +# ---------------------------------------------- +# 4. 
Tuples +# ---------------------------------------------- + +echo "=== Tuples ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_tuples_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_tuples_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_tuples_src (a Tuple(UInt32, String, Float64), b Tuple(Tuple(UInt32, UInt32), String)) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_tuples_dst AS arrow_test_tuples_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_tuples_src VALUES ((1, 'hello', 3.14), ((10, 20), 'nested')), ((0, '', 0), ((0, 0), '')), ((42, 'world', -1.5), ((100, 200), 'deep'))" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_tuples_src ORDER BY a.1 FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_tuples_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_tuples_dst ORDER BY a.1" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_tuples_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_tuples_dst" + +# ---------------------------------------------- +# 5. 
Named Tuples +# ---------------------------------------------- + +echo "=== Named Tuples ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_named_tuples_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_named_tuples_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_named_tuples_src (t Tuple(x UInt32, y String, z Float64)) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_named_tuples_dst AS arrow_test_named_tuples_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_named_tuples_src VALUES ((1, 'abc', 1.1)), ((2, 'def', 2.2)), ((3, 'ghi', 3.3))" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_named_tuples_src ORDER BY t.x FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_named_tuples_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_named_tuples_dst ORDER BY t.x" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_named_tuples_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_named_tuples_dst" + +# ---------------------------------------------- +# 6. Variant (export only - no Arrow import for DenseUnion) +# ---------------------------------------------- + +echo "=== Variants ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_variants" +${CLICKHOUSE_CLIENT} --query="SET allow_suspicious_variant_types = 1; CREATE TABLE arrow_test_variants (v Variant(UInt32, String, Float64)) ENGINE = Memory" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_variants VALUES (42::UInt32), ('hello'), (3.14::Float64), (NULL), (0::UInt32), (''), (NULL)" + +# Variant -> Arrow DenseUnion export should succeed without errors. +# We verify the output is non-empty valid Arrow data. 
+ARROW_SIZE=$(${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_variants FORMAT Arrow" | wc -c) +if [ "$ARROW_SIZE" -gt 0 ]; then + echo "Variant export OK, size: positive" +else + echo "Variant export FAILED: empty output" +fi + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_variants" + +# ---------------------------------------------- +# 7. Nested combinations: Array(Tuple), Map(String, Array), Tuple(Array, Map) +# ---------------------------------------------- + +echo "=== Nested Combinations ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_combo_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_combo_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_combo_src (arr_of_tuples Array(Tuple(UInt32, String)), map_of_arrays Map(String, Array(UInt32)), tuple_with_map Tuple(Array(UInt32), Map(String, String))) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_combo_dst AS arrow_test_combo_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_combo_src VALUES ([(1,'a'),(2,'b')], {'x':[1,2,3],'y':[4,5]}, ([10,20], {'k':'v'})), ([], {}, ([], {})), ([(42,'only')], {'single':[100]}, ([0], {'':'empty'}))" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_combo_src ORDER BY length(arr_of_tuples) FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_combo_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_combo_dst ORDER BY length(arr_of_tuples)" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_combo_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_combo_dst" + +# ---------------------------------------------- +# 8. 
Multi-block roundtrip with complex types +# ---------------------------------------------- + +echo "=== Multi-block ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_multiblock_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_multiblock_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_multiblock_src (id UInt32, arr Array(UInt32), m Map(String, UInt64), t Tuple(UInt32, String)) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_multiblock_dst AS arrow_test_multiblock_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_multiblock_src SELECT number, [number, number*2, number*3], map('k' || toString(number), number), (number, toString(number)) FROM numbers(10)" + +${CLICKHOUSE_CLIENT} --max_block_size=2 --query="SELECT * FROM arrow_test_multiblock_src ORDER BY id FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_multiblock_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_multiblock_dst ORDER BY id" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_multiblock_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_multiblock_dst" + +# ---------------------------------------------- +# 9. 
Map with Nullable values +# ---------------------------------------------- + +echo "=== Map with Nullable values ===" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_map_nullable_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE IF EXISTS arrow_test_map_nullable_dst" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_map_nullable_src (id UInt32, m Map(String, Nullable(UInt64))) ENGINE = Memory" +${CLICKHOUSE_CLIENT} --query="CREATE TABLE arrow_test_map_nullable_dst AS arrow_test_map_nullable_src" + +${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_map_nullable_src VALUES (1, {'a': 1, 'b': NULL}), (2, {}), (3, {'c': NULL, 'd': NULL}), (4, {'e': 42})" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_map_nullable_src ORDER BY id FORMAT Arrow" | \ + ${CLICKHOUSE_CLIENT} --query="INSERT INTO arrow_test_map_nullable_dst FORMAT Arrow" + +${CLICKHOUSE_CLIENT} --query="SELECT * FROM arrow_test_map_nullable_dst ORDER BY id" + +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_map_nullable_src" +${CLICKHOUSE_CLIENT} --query="DROP TABLE arrow_test_map_nullable_dst" From f3f529fd4f48faa3b2c3d4400eae1a5865421bd3 Mon Sep 17 00:00:00 2001 From: Andrey Zvonov Date: Fri, 15 May 2026 21:08:46 +0200 Subject: [PATCH 2/4] Resolve conflicts in cherry-pick of #91170 --- ci/jobs/scripts/check_style/check_cpp.sh | 46 +- src/Core/SettingsChangesHistory.cpp | 20 - .../Formats/Impl/CHColumnToArrowColumn.cpp | 229 +-- src/Server/ArrowFlight/ArrowFlightServer.cpp | 3 +- src/Server/ArrowFlightHandler.cpp | 1412 ----------------- 5 files changed, 16 insertions(+), 1694 deletions(-) delete mode 100644 src/Server/ArrowFlightHandler.cpp diff --git a/ci/jobs/scripts/check_style/check_cpp.sh b/ci/jobs/scripts/check_style/check_cpp.sh index 621ba855f373..cd327a8a97b2 100755 --- a/ci/jobs/scripts/check_style/check_cpp.sh +++ b/ci/jobs/scripts/check_style/check_cpp.sh @@ -340,7 +340,7 @@ ls -1d $ROOT_PATH/contrib/*-cmake | xargs -I@ find @ -name 'CMakeLists.txt' -or # 
Wrong spelling of abbreviations, e.g. SQL is right, Sql is wrong. XMLHttpRequest is very wrong. find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | grep -vP $EXCLUDE | - xargs grep -P 'Sql|Html|Xml|Cpu|Tcp|Udp|Http|Db|Json|Yaml' | grep -v -P 'RabbitMQ|Azure|Aws|aws|Avro|IO/S3|ai::JsonValue|IcebergWrites|arrow::flight|TcpExtListenOverflows' && + xargs grep -P 'Sql|Html|Xml|Cpu|Tcp|Udp|Http|Db|Json|Yaml' | grep -v -P 'RabbitMQ|Azure|Aws|aws|Avro|IO/S3|ai::JsonValue|IcebergWrites|arrow::flight|SqlInfo|CommandGetSqlInfo|CommandGetDbSchemas|commandGetDbSchemas|ArrowFlightSql|TcpExtListenOverflows' && echo "Abbreviations such as SQL, XML, HTTP, should be in all caps. For example, SQL is right, Sql is wrong. XMLHttpRequest is very wrong." find $ROOT_PATH/{src,base,programs,utils} -name '*.h' -or -name '*.cpp' | @@ -352,50 +352,6 @@ PATTERN="allow_"; DIFF=$(comm -3 <(grep -o "\b$PATTERN\w*\b" $ROOT_PATH/src/Core/Settings.cpp | sort -u) <(grep -o -h "\b$PATTERN\w*\b" $ROOT_PATH/src/Databases/enableAllExperimentalSettings.cpp $ROOT_PATH/ci/jobs/scripts/check_style/experimental_settings_ignore.txt | sort -u)); [ -n "$DIFF" ] && echo "$DIFF" && echo "^^ Detected 'allow_*' settings that might need to be included in src/Databases/enableAllExperimentalSettings.cpp" && echo "Alternatively, consider adding an exception to ci/jobs/scripts/check_style/experimental_settings_ignore.txt" -<<<<<<< HEAD -======= -# 12a: NDEBUG and cast checks on nobase_all -{ -# A small typo can lead to debug code in release builds, see https://github.com/ClickHouse/ClickHouse/pull/47647 -xargs < "$STYLE_TMPDIR/nobase_all" grep -l -F '#ifdef NDEBUG' | \ - xargs awk '/#ifdef NDEBUG/ { inside = 1; dirty = 1 } /#endif/ { if (inside && dirty) { print "File " FILENAME " has suspicious #ifdef NDEBUG, possibly confused with #ifndef NDEBUG" }; inside = 0 } /#else/ { dirty = 0 }' - -# If a user is doing dynamic or typeid cast with a pointer, and immediately dereferencing it, it is unsafe. 
-xargs < "$STYLE_TMPDIR/nobase_all" rg --line-number '(dynamic|typeid)_cast<[^>]+\*>\([^\(\)]+\)->' | grep . && echo "It's suspicious when you are doing a dynamic_cast or typeid_cast with a pointer and immediately dereferencing it. Use references instead of pointers or check a pointer to nullptr." -} > "$O.12a" 2>&1 & - -# 12b: Punctuation, std::regex, and Cyrillic checks on nobase_all -{ -# Check for bad punctuation: whitespace before comma. -xargs < "$STYLE_TMPDIR/nobase_all" rg --line-number '\w ,' | grep -v 'bad punctuation is ok here' && echo "^ There is bad punctuation: whitespace before comma. You should write it like this: 'Hello, world!'" - -# Check usage of std::regex which is too bloated and slow. -xargs < "$STYLE_TMPDIR/nobase_all" grep -F --line-number 'std::regex' | grep . && echo "^ Please use re2 instead of std::regex" - -# Cyrillic characters hiding inside Latin. -grep -v StorageSystemContributors.generated.cpp "$STYLE_TMPDIR/nobase_all" | \ - xargs rg --line-number '[a-zA-Z][а-яА-ЯёЁ]|[а-яА-ЯёЁ][a-zA-Z]' && echo "^ Cyrillic characters found in unexpected place." -} > "$O.12b" 2>&1 & - -# 13: Orphaned header files -{ -join -v1 <(grep '\.h$' "$STYLE_TMPDIR/nobase_all" | sed 's:.*/::' | sort -u) <(rg --no-filename -o '[\w-]+\.h' --glob '*.cpp' --glob '*.c' --glob '*.h' --glob '*.S' $ROOT_PATH/src $ROOT_PATH/programs $ROOT_PATH/utils $ROOT_PATH/tests/lexer | sort -u) | - grep . && echo '^ Found orphan header files.' -} > "$O.13" 2>&1 & - -# 14: Abbreviation checks and error message style -{ -# Wrong spelling of abbreviations, e.g. SQL is right, Sql is wrong. XMLHttpRequest is very wrong. -xargs < "$STYLE_TMPDIR/all_excluded" rg 'Sql|Html|Xml|Cpu|Tcp|Udp|Http|Db|Json|Yaml' | grep -v -E 'RabbitMQ|Azure|Aws|aws|Avro|IO/S3|ai::JsonValue|IcebergWrites|arrow::flight|SqlInfo|CommandGetSqlInfo|CommandGetDbSchemas|commandGetDbSchemas|ArrowFlightSql|TcpExtListenOverflows' && - echo "Abbreviations such as SQL, XML, HTTP, should be in all caps. 
For example, SQL is right, Sql is wrong. XMLHttpRequest is very wrong." - -xargs < "$STYLE_TMPDIR/all_excluded" grep -F -i 'ErrorCodes::LOGICAL_ERROR, "Logical error:' && - echo "If an exception has LOGICAL_ERROR code, there is no need to include the text 'Logical error' in the exception message, because then the phrase 'Logical error' will be printed twice." -} > "$O.14" 2>&1 & - -# 15: magic_enum and std::format -{ ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) # Don't allow the direct inclusion of magic_enum.hpp and instead point to base/EnumReflection.h find $ROOT_PATH/{src,base,programs,utils} -name '*.cpp' -or -name '*.h' | xargs grep -l "magic_enum.hpp" | grep -v EnumReflection.h | while read -r line; do diff --git a/src/Core/SettingsChangesHistory.cpp b/src/Core/SettingsChangesHistory.cpp index 2fb63c9ef2c3..d2f4f83c11d8 100644 --- a/src/Core/SettingsChangesHistory.cpp +++ b/src/Core/SettingsChangesHistory.cpp @@ -44,26 +44,6 @@ const VersionToSettingsChangesMap & getSettingsChangesHistory() {"object_storage_cluster_join_mode", "allow", "allow", "New setting"}, {"output_format_arrow_unsupported_types_as_binary", false, true, "New setting to convert unsupported CH types to arrow binary instead of UNKNOWN_TYPE exception."}, {"output_format_parquet_unsupported_types_as_binary", false, false, "New setting to convert unsupported CH types to parquet (arrow) binary instead of UNKNOWN_TYPE exception."}, - {"asterisk_include_virtual_columns", false, false, "New setting"}, - {"max_wkb_geometry_elements", 1'000'000, 1'000'000, "New setting to limit element counts in WKB geometry parsing, preventing excessive memory allocation on malformed data."}, - {"max_rand_distribution_trials", 1'000'000'000, 1'000'000'000, "New setting to limit trial counts in random distribution functions, preventing hangs with extreme inputs."}, - {"max_rand_distribution_parameter", 1e6, 1e6, "New setting to limit shape parameters in random distribution 
functions, preventing hangs with extreme inputs."}, - {"optimize_truncate_order_by_after_group_by_keys", false, true, "Remove trailing ORDER BY elements once all GROUP BY keys are covered in the ORDER BY prefix."}, - {"use_statistics_for_part_pruning", false, true, "New setting to use statistics for part pruning during query execution."}, - {"distributed_index_analysis_only_on_coordinator", false, false, "New setting."}, - {"query_plan_optimize_join_order_randomize", 0, 0, "New setting to randomize join order statistics for testing."}, - {"enable_materialized_cte", false, false, "New setting"}, - {"use_strict_insert_block_limits", false, false, "New setting to use strict min and max insert bounds on inserts. When min < max, max limits take precedence."}, - {"finalize_projection_parts_synchronously", false, false, "New setting to finalize projection parts synchronously during INSERT to reduce peak memory usage."}, - {"read_in_order_use_virtual_row_per_block", false, false, "Emit virtual row after each block during read-in-order to allow more frequent source reprioritization in MergingSortedTransform."}, - {"distributed_plan_prefer_replicas_over_workers", false, false, "New setting to serialize distributed plan for replicas"}, - {"use_text_index_like_evaluation_by_dictionary_scan", true, true, "New setting"}, - {"text_index_like_min_pattern_length", 4, 4, "New setting"}, - {"text_index_like_max_postings_to_read", 50, 50, "New setting"}, - {"analyzer_inline_views", false, false, "New setting"}, - {"highlight_max_matches_per_row", 10000, 10000, "New setting to limit the number of highlight matches per row to protect against excessive memory usage."}, - {"materialize_statistics_on_insert", true, false, "Disable building statistics on INSERT by default, rely on merges instead"}, - {"enable_join_transitive_predicates", false, false, "New setting to infer transitive equi-join predicates for join order optimization."}, }); addSettingsChanges(settings_changes_history, 
"26.3", { diff --git a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp index 956699c4c0e2..35bf14b7f2bd 100644 --- a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp +++ b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp @@ -22,13 +22,7 @@ #include #include #include -<<<<<<< HEAD -======= #include -#include -#include -#include ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) #include #include #include @@ -37,16 +31,12 @@ #include #include #include -<<<<<<< HEAD -======= #include #include #include #include #include #include -#include ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) #include #include @@ -304,55 +294,7 @@ namespace DB } } -<<<<<<< HEAD - static void fillArrowArray( -======= - static void fillArrowArrayWithUUIDColumnData( - const ColumnPtr & column, - const PaddedPODArray * null_bytemap, - const String & format_name, - arrow::ArrayBuilder * array_builder, - size_t start, - size_t end) - { - const auto * col_uuid = assert_cast *>(column.get()); - - if (array_builder->type()->id() != arrow::Type::FIXED_SIZE_BINARY) - throw Exception(ErrorCodes::LOGICAL_ERROR, - "Cannot fill arrow array with {} data for format {}", column->getName(), format_name); - - auto * fixed_builder = assert_cast(array_builder); - const auto & uuid_data = col_uuid->getData(); - - for (size_t i = start; i < end; ++i) - { - if (null_bytemap && (*null_bytemap)[i]) - { - arrow::Status status = fixed_builder->AppendNull(); - checkStatus(status, column->getName(), format_name); - continue; - } - - UUID res = uuid_data[i]; - auto * bytes = reinterpret_cast(&res); - - if constexpr (std::endian::native == std::endian::little) - { - std::reverse(bytes, bytes + 8); - std::reverse(bytes + 8, bytes + 16); - } - else - { - std::swap_ranges(bytes, bytes + 8, bytes + 8); - } - - arrow::Status status = fixed_builder->Append(reinterpret_cast(&res)); - 
checkStatus(status, column->getName(), format_name); - } - } - static std::shared_ptr fillArrowArray( ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) const String & column_name, ColumnPtr column, const DataTypePtr & column_type, @@ -366,7 +308,7 @@ namespace DB static std::shared_ptr getArrowType( - DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder = false); + DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable); static std::shared_ptr buildArrowDenseUnionArrayWithVariantColumnData( @@ -1338,11 +1280,7 @@ namespace DB } static std::shared_ptr getArrowType( -<<<<<<< HEAD DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable) -======= - DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder) ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) { if (column) { @@ -1353,13 +1291,8 @@ namespace DB if (column_type->isNullable()) { DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); -<<<<<<< HEAD - ColumnPtr nested_column = assert_cast(column.get())->getNestedColumnPtr(); - auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable); -======= ColumnPtr nested_column = column ? 
assert_cast(column.get())->getNestedColumnPtr() : nullptr; - auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) + auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable); *out_is_column_nullable = true; return arrow_type; } @@ -1408,11 +1341,7 @@ namespace DB for (size_t i = 0; i != nested_types.size(); ++i) { bool is_field_nullable = false; -<<<<<<< HEAD - auto nested_arrow_type = getArrowType(nested_types[i], tuple_column->getColumnPtr(i), nested_names[i], format_name, settings, &is_field_nullable); -======= - auto nested_arrow_type = getArrowType(nested_types[i], tuple_column ? tuple_column->getColumnPtr(i) : nullptr, nested_names[i], format_name, settings, &is_field_nullable, for_builder); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) + auto nested_arrow_type = getArrowType(nested_types[i], tuple_column ? 
tuple_column->getColumnPtr(i) : nullptr, nested_names[i], format_name, settings, &is_field_nullable); nested_fields.push_back(std::make_shared(nested_names[i], nested_arrow_type, is_field_nullable)); } return arrow::struct_(nested_fields); @@ -1421,14 +1350,6 @@ namespace DB if (column_type->lowCardinality()) { auto nested_type = assert_cast(column_type.get())->getDictionaryType(); -<<<<<<< HEAD - const auto * lc_column = assert_cast(column.get()); - const auto & nested_column = lc_column->getDictionary().getNestedColumn(); - const auto & indexes_column = lc_column->getIndexesPtr(); - return arrow::dictionary( - getArrowTypeForLowCardinalityIndexes(indexes_column, settings), - getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable)); -======= if (column) { const auto * lc_column = assert_cast(column.get()); @@ -1436,17 +1357,16 @@ namespace DB const auto & indexes_column = lc_column->getIndexesPtr(); return arrow::dictionary( getArrowTypeForLowCardinalityIndexes(indexes_column, settings), - getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder)); + getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable)); } else { auto index_arrow_type = settings.use_64_bit_indexes_for_dictionary ? (settings.use_signed_indexes_for_dictionary ? arrow::int64() : arrow::uint64()) : (settings.use_signed_indexes_for_dictionary ? 
arrow::int32() : arrow::uint32()); - auto arrow_type = getArrowType(nested_type, nullptr, column_name, format_name, settings, out_is_column_nullable, for_builder); + auto arrow_type = getArrowType(nested_type, nullptr, column_name, format_name, settings, out_is_column_nullable); return arrow::dictionary(index_arrow_type, arrow_type); } ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } if (isMap(column_type)) @@ -1463,17 +1383,10 @@ namespace DB value_column = columns[1]; } -<<<<<<< HEAD - bool _is_key_nullable = false; - auto key_arrow_type = getArrowType(key_type, columns[0], column_name, format_name, settings, &_is_key_nullable); - bool is_val_nullable = false; - auto val_arrow_type = getArrowType(val_type, columns[1], column_name, format_name, settings, &is_val_nullable); -======= bool is_key_nullable = false; - auto key_arrow_type = getArrowType(key_type, key_column, column_name, format_name, settings, &is_key_nullable, for_builder); + auto key_arrow_type = getArrowType(key_type, key_column, column_name, format_name, settings, &is_key_nullable); bool is_val_nullable = false; - auto val_arrow_type = getArrowType(val_type, value_column, column_name, format_name, settings, &is_val_nullable, for_builder); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) + auto val_arrow_type = getArrowType(val_type, value_column, column_name, format_name, settings, &is_val_nullable); return arrow::map( key_arrow_type, @@ -1504,8 +1417,6 @@ namespace DB if (isIPv4(column_type)) return arrow::uint32(); -<<<<<<< HEAD -======= if (isVariant(column_type)) { const auto * column_variant = column ? &assert_cast(*column) : nullptr; @@ -1535,8 +1446,7 @@ namespace DB variant ? 
variant->getName() : "variant", format_name, settings, - &is_column_nullable, - for_builder); + &is_column_nullable); std::string field_name = column_variant_type.getVariant(i)->getFamilyName(); fields.push_back(std::make_shared(field_name, arrow_type, is_column_nullable)); @@ -1550,23 +1460,7 @@ namespace DB return arrow::dense_union(fields); } - if (isInterval(column_type)) - { - const auto * interval_type = assert_cast(column_type.get()); - switch (interval_type->getKind()) - { - case IntervalKind::Kind::Nanosecond: return arrow::duration(arrow::TimeUnit::NANO); - case IntervalKind::Kind::Microsecond: return arrow::duration(arrow::TimeUnit::MICRO); - case IntervalKind::Kind::Millisecond: return arrow::duration(arrow::TimeUnit::MILLI); - case IntervalKind::Kind::Second: return arrow::duration(arrow::TimeUnit::SECOND); - default: return arrow::int64(); - } - } - - if (isUUID(column_type)) - return for_builder ? arrow::fixed_size_binary(sizeof(UUID)) : std::make_shared(); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) if (isDate(column_type) && settings.output_date_as_uint16) return arrow::uint16(); @@ -1624,26 +1518,13 @@ namespace DB settings, &is_column_nullable); - std::shared_ptr field_metadata = nullptr; - if (column_to_field_id && column_to_field_id->contains(header_column.name)) { Int64 field_id = column_to_field_id->at(header_column.name); - field_metadata = arrow::key_value_metadata({"PARQUET:field_id"}, {std::to_string(field_id)}); - } - - // Inject our UUID metadata if it's a root UUID column - if (isUUID(removeNullable(header_column.type))) - { - auto ext_metadata = arrow::key_value_metadata( - {"ARROW:extension:name", "ARROW:extension:metadata", "PARQUET:logical_type"}, - {"arrow.uuid", "", "UUID"} - ); - field_metadata = field_metadata ? 
field_metadata->Merge(*ext_metadata) : ext_metadata; + auto key_value_metadata = arrow::key_value_metadata({"PARQUET:field_id"}, + {std::to_string(field_id)}); + arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable, key_value_metadata)); } - - if (field_metadata) - arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable, field_metadata)); else arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable)); } @@ -1684,13 +1565,12 @@ namespace DB column_type = recursiveRemoveLowCardinality(column_type); } - // Generate the unwrapped builder schema (safe for MakeBuilder) bool is_column_nullable = false; auto builder_type = getArrowType( - column_type, column, header_column.name, format_name, settings, &is_column_nullable, true /* for_builder */); + column_type, column, header_column.name, format_name, settings, &is_column_nullable); std::unique_ptr array_builder; - arrow::Status status = MakeBuilder(arrow::default_memory_pool(), builder_type, &array_builder); + arrow::Status status = MakeBuilder(ArrowMemoryPool::instance(), builder_type, &array_builder); checkStatus(status, column->getName(), format_name); std::shared_ptr arrow_array = fillArrowArray( @@ -1705,7 +1585,6 @@ namespace DB settings, dictionary_values); - // Zero-copy cast to the extension-rich schema (handles infinite nesting) auto target_type = schema->field(static_cast(column_i))->type(); if (!arrow_array->type()->Equals(*target_type)) arrow_array = checkResult(arrow_array->View(target_type), column->getName(), format_name); @@ -1759,45 +1638,7 @@ namespace DB { if (arrow_schema) return; -<<<<<<< HEAD - - if (!columns_num) - columns_num = header_columns.size(); - - std::vector> arrow_fields; - arrow_fields.reserve(*columns_num); - - for (size_t column_i = 0; column_i < *columns_num; ++column_i) - { - const ColumnWithTypeAndName & header_column = header_columns[column_i]; - auto column = chunk 
? chunk->getColumns()[column_i] : header_column.column; - - if (!settings.low_cardinality_as_dictionary) - column = recursiveRemoveLowCardinality(column); - - bool is_column_nullable = false; - auto arrow_type = getArrowType( - header_column.type, - column, - header_column.name, - format_name, - settings, - &is_column_nullable); - if (column_to_field_id && column_to_field_id->contains(header_column.name)) - { - Int64 field_id = column_to_field_id->at(header_column.name); - auto key_value_metadata = arrow::key_value_metadata({"PARQUET:field_id"}, - {std::to_string(field_id)}); - arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable, key_value_metadata)); - } - else - arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable)); - } - - arrow_schema = std::make_shared(arrow_fields); -======= arrow_schema = calculateArrowSchema(header_columns, format_name, chunk, settings, columns_num, column_to_field_id); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } std::shared_ptr CHColumnToArrowColumn::getArrowSchema() const @@ -1817,51 +1658,7 @@ namespace DB const Chunk * chunk_to_initialize_schema = chunks.empty() ? 
nullptr : chunks.data(); initializeArrowSchema(chunk_to_initialize_schema, columns_num, column_to_field_id); -<<<<<<< HEAD - for (const auto & chunk : chunks) - { - /// For arrow::Table creation - for (size_t column_i = 0; column_i < columns_num; ++column_i) - { - const ColumnWithTypeAndName & header_column = header_columns[column_i]; - auto column = chunk.getColumns()[column_i]; - - if (!settings.low_cardinality_as_dictionary) - column = recursiveRemoveLowCardinality(column); - - std::unique_ptr array_builder; - arrow::Status status = MakeBuilder(ArrowMemoryPool::instance(), arrow_schema->field(static_cast(column_i))->type(), &array_builder); - checkStatus(status, column->getName(), format_name); - - fillArrowArray( - header_column.name, - column, - header_column.type, - nullptr, - array_builder.get(), - format_name, - 0, - column->size(), - settings, - dictionary_values); - - std::shared_ptr arrow_array; - status = array_builder->Finish(&arrow_array); - checkStatus(status, column->getName(), format_name); - - table_data.at(column_i).emplace_back(std::move(arrow_array)); - } - } - - std::vector> columns; - columns.reserve(columns_num); - for (size_t column_i = 0; column_i < columns_num; ++column_i) - columns.emplace_back(std::make_shared(table_data.at(column_i))); - - res = arrow::Table::Make(arrow_schema, columns); -======= res = calculateArrowTable(header_columns, format_name, chunks, settings, columns_num, arrow_schema, &dictionary_values); ->>>>>>> e02e0dd65eb (Merge pull request #91170 from ClickHouse/feat-arrowflight-impl) } } diff --git a/src/Server/ArrowFlight/ArrowFlightServer.cpp b/src/Server/ArrowFlight/ArrowFlightServer.cpp index b70f2bebd29e..323affa0196b 100644 --- a/src/Server/ArrowFlight/ArrowFlightServer.cpp +++ b/src/Server/ArrowFlight/ArrowFlightServer.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -984,7 +985,7 @@ arrow::Status ArrowFlightServer::DoPut( if (pipeline.pushing()) { Block header = 
pipeline.getHeader(); - auto input = std::make_shared(std::move(reader), header, query_context); + auto input = std::make_shared(std::move(reader), header); pipeline.complete(Pipe(std::move(input))); } else if (pipeline.pulling()) diff --git a/src/Server/ArrowFlightHandler.cpp b/src/Server/ArrowFlightHandler.cpp deleted file mode 100644 index febe6f0fb8ac..000000000000 --- a/src/Server/ArrowFlightHandler.cpp +++ /dev/null @@ -1,1412 +0,0 @@ -#include - -#if USE_ARROWFLIGHT - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_EXCEPTION; -} - -namespace -{ - const std::string AUTHORIZATION_HEADER = "authorization"; - const std::string AUTHORIZATION_MIDDLEWARE_NAME = "authorization_middleware"; - - class AuthMiddleware : public arrow::flight::ServerMiddleware - { - public: - explicit AuthMiddleware(const std::string & token, const std::string & username, const std::string & password) - : token_(token) - , username_(username) - , password_(password) - { - } - - static AuthMiddleware & get(const arrow::flight::ServerCallContext & context) - { - return *static_cast(context.GetMiddleware(AUTHORIZATION_MIDDLEWARE_NAME)); - } - - const std::string & username() const { return username_; } - const std::string & password() const { return password_; } - - void SendingHeaders(arrow::flight::AddCallHeaders * outgoing_headers) override - { - outgoing_headers->AddHeader(AUTHORIZATION_HEADER, "Bearer " + token_); - } - - void CallCompleted(const arrow::Status & /*status*/) override { } - - std::string name() const override { return AUTHORIZATION_MIDDLEWARE_NAME; } - - private: - const std::string token_; - const std::string username_; - const std::string password_; - }; 
- - class AuthMiddlewareFactory : public arrow::flight::ServerMiddlewareFactory - { - public: - arrow::Status StartCall( - const arrow::flight::CallInfo & /*info*/, - const arrow::flight::ServerCallContext & context, - std::shared_ptr * middleware) override - { - const auto & headers = context.incoming_headers(); - - auto it = headers.find(AUTHORIZATION_HEADER); - if (it == headers.end()) - return arrow::Status::IOError("Missing Authorization header"); - - auto auth_header = std::string(it->second); - - std::string token; - - const std::string prefix_basic = "Basic "; - if (auth_header.starts_with(prefix_basic)) - token = auth_header.substr(prefix_basic.size()); - - const std::string prefix_bearer = "Bearer "; - if (auth_header.starts_with(prefix_bearer)) - token = auth_header.substr(prefix_bearer.size()); - - if (token.empty()) - return arrow::Status::IOError("Expected Basic auth scheme"); - - std::string credentials = base64Decode(token, true); - auto pos = credentials.find(':'); - if (pos == std::string::npos) - return arrow::Status::IOError("Malformed credentials"); - - auto user = credentials.substr(0, pos); - auto password = credentials.substr(pos + 1); - - *middleware = std::make_unique(token, user, password); - return arrow::Status::OK(); - } - }; - - String readFile(const String & filepath) - { - Poco::FileInputStream ifs(filepath); - String buf; - Poco::StreamCopier::copyToString(ifs, buf); - return buf; - } - - arrow::flight::Location addressToArrowLocation(const Poco::Net::SocketAddress & address_to_listen, bool use_tls) - { - auto ip_to_listen = address_to_listen.host(); - auto port_to_listen = address_to_listen.port(); - - /// Function arrow::flight::Location::ForGrpc*() builds an URL so it requires IPv6 address to be enclosed in brackets - String host_component = (ip_to_listen.family() == Poco::Net::AddressFamily::IPv6) ? 
("[" + ip_to_listen.toString() + "]") : ip_to_listen.toString(); - - arrow::Result parse_location_status; - if (use_tls) - parse_location_status = arrow::flight::Location::ForGrpcTls(host_component, port_to_listen); - else - parse_location_status = arrow::flight::Location::ForGrpcTcp(host_component, port_to_listen); - - if (!parse_location_status.ok()) - { - throw Exception( - ErrorCodes::UNKNOWN_EXCEPTION, - "Invalid address {} for Arrow Flight Server: {}", - address_to_listen.toString(), - parse_location_status.status().ToString()); - } - - return std::move(parse_location_status).ValueOrDie(); - } - - /// Extracts the client's address from the call context. - Poco::Net::SocketAddress getClientAddress(const arrow::flight::ServerCallContext & context) - { - /// Returns a string like ipv4:127.0.0.1:55930 or ipv6:%5B::1%5D:55930 - String uri_encoded_peer = context.peer(); - - constexpr const std::string_view ipv4_prefix = "ipv4:"; - constexpr const std::string_view ipv6_prefix = "ipv6:"; - - bool ipv4 = uri_encoded_peer.starts_with(ipv4_prefix); - bool ipv6 = uri_encoded_peer.starts_with(ipv6_prefix); - - if (!ipv4 && !ipv6) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected ipv4 or ipv6 protocol in peer address, got {}", uri_encoded_peer); - - auto prefix = ipv4 ? ipv4_prefix : ipv6_prefix; - auto family = ipv4 ? Poco::Net::AddressFamily::Family::IPv4 : Poco::Net::AddressFamily::Family::IPv6; - - uri_encoded_peer = uri_encoded_peer.substr(prefix.length()); - - String peer; - Poco::URI::decode(uri_encoded_peer, peer); - - return Poco::Net::SocketAddress{family, peer}; - } - - /// Extracts an SQL query from a flight descriptor. - /// It depends on the flight descriptor's type (PATH/CMD) and on the operation's type (DoPut/DoGet). 
- [[nodiscard]] arrow::Result convertDescriptorToSQL(const arrow::flight::FlightDescriptor & descriptor, bool for_put_operation) - { - switch (descriptor.type) - { - case arrow::flight::FlightDescriptor::PATH: - { - const auto & path = descriptor.path; - if (path.size() != 1) - return arrow::Status::Invalid("Flight descriptor's path should be one-component (got ", path.size(), " components)"); - if (path[0].empty()) - return arrow::Status::Invalid("Flight descriptor's path should specify the name of a table"); - const String & table_name = path[0]; - if (for_put_operation) - return "INSERT INTO " + backQuoteIfNeed(table_name) + " FORMAT Arrow"; - else - return "SELECT * FROM " + backQuoteIfNeed(table_name); - } - case arrow::flight::FlightDescriptor::CMD: - { - const auto & cmd = descriptor.cmd; - if (cmd.empty()) - return arrow::Status::Invalid("Flight descriptor's command should specify a SQL query"); - return cmd; - } - default: - return arrow::Status::TypeError("Flight descriptor has unknown type ", magic_enum::enum_name(descriptor.type)); - } - } - - [[nodiscard]] arrow::Result convertGetDescriptorToSQL(const arrow::flight::FlightDescriptor & descriptor) - { - return convertDescriptorToSQL(descriptor, /* for_put_operation = */ false); - } - - [[nodiscard]] arrow::Result convertPutDescriptorToSQL(const arrow::flight::FlightDescriptor & descriptor) - { - return convertDescriptorToSQL(descriptor, /* for_put_operation = */ true); - } - - /// For method doGet() the pipeline should have an output. - [[nodiscard]] arrow::Status checkPipelineIsPulling(const QueryPipeline & pipeline) - { - if (!pipeline.pulling()) - return arrow::Status::Invalid("Query doesn't allow pulling data, use method doPut() with this kind of query"); - return arrow::Status::OK(); - } - - /// We don't allow custom formats except "Arrow" because they can't work with ArrowFlight. 
- [[nodiscard]] arrow::Status checkNoCustomFormat(ASTPtr ast) - { - if (const auto * ast_with_output = dynamic_cast(ast.get())) - { - if (ast_with_output->format_ast && (getIdentifierName(ast_with_output->format_ast) != "Arrow")) - return arrow::Status::ExecutionError("Invalid format, only 'Arrow' format is supported"); - } - else if (const auto * insert = dynamic_cast(ast.get())) - { - if (!insert->format.empty() && insert->format != "Arrow") - return arrow::Status::ExecutionError("Invalid format, only 'Arrow' format is supported"); - } - return arrow::Status::OK(); - } - - using Timestamp = std::chrono::system_clock::time_point; - using Duration = std::chrono::system_clock::duration; - - Timestamp now() - { - return std::chrono::system_clock::now(); - } - - /// We use the ALREADY_EXPIRED timestamp (January 1, 1970) as the expiration time of a ticket or a poll descriptor - /// which is already expired. - const Timestamp ALREADY_EXPIRED = Timestamp{Duration{0}}; - - /// We generate tickets with this prefix. - /// Method DoGet() accepts a ticket which is either 1) a ticket with this prefix; or 2) a SQL query. - /// A valid SQL query can't start with this prefix so method DoGet() can distinguish those cases. - const String TICKET_PREFIX = "~TICKET-"; - - bool hasTicketPrefix(const String & ticket) - { - return ticket.starts_with(TICKET_PREFIX); - } - - /// We generate poll descriptors with this prefix. - /// Methods PollFlightInfo() or GetSchema() accept a flight descriptor which is either - /// 1) a normal flight descriptor (a table name or a SQL query); or 2) a poll descriptor with this prefix. - /// A valid SQL query can't start with this prefix so methods PollFlightInfo() and GetSchema() can distinguish those cases. - const String POLL_DESCRIPTOR_PREFIX = "~POLL-"; - - bool hasPollDescriptorPrefix(const String & poll_descriptor) - { - return poll_descriptor.starts_with(POLL_DESCRIPTOR_PREFIX); - } - - /// A ticket name with its expiration time. 
- struct TicketWithExpirationTime - { - String ticket; - /// When the ticket expires. - /// std::nullopt means that the ticket expires after using it in DoGet(). - /// Can be equal to ALREADY_EXPIRED. - std::optional expiration_time; - }; - - /// A poll descriptor's name with its expiration time. - struct PollDescriptorWithExpirationTime - { - String poll_descriptor; - /// When the poll descriptor expires. - /// std::nullopt means that the poll descriptor expires after using it in PollFlightInfo(); - /// Can be equal to ALREADY_EXPIRED. - std::optional expiration_time; - }; - - /// Keeps a block associated with a ticket. - struct TicketInfo : public TicketWithExpirationTime - { - ConstBlockPtr block; - std::shared_ptr ch_to_arrow_converter; - }; - - /// Information about a poll descriptor. - /// Objects of type PollDescriptorInfo are stored as a kind of a doubly linked list, - /// the previous object is stored as `previous_info`, and the next object is referenced by `next_poll_descriptor`. - struct PollDescriptorInfo : public PollDescriptorWithExpirationTime - { - std::shared_ptr ch_to_arrow_converter; - std::shared_ptr previous_info; - bool evaluating = false; - bool evaluated = false; - - /// The following fields can be set only if `evaluated == true`: - - /// A success or error error. - std::optional status; - - /// A new ticket. Along with tickets from previous infos (previous_info, previous_info->previous_info, etc.) - /// represents all tickets associated with this poll descriptor. - /// Can be unset if there is no block; or it can specify an already expired ticket. - std::optional ticket; - - /// Adds rows. Along with added rows from previous infos (previous_info, previous_info->previous_info, etc.) - /// represents the total number of rows associated with this poll descriptor. - /// Can be unset if there is no rows added. - std::optional rows; - - /// Adds bytes. Along with added bytes from previous infos (previous_info, previous_info->previous_info, etc.) 
- /// represents the total number of bytes associated with this poll descriptor. - /// Can be unset if there is no bytes added. - std::optional bytes; - - /// Next poll descriptor if any. - /// Can be unset if there is no next poll descriptor (no more blocks are to pull from the query pipeline). - std::optional next_poll_descriptor; - }; - - /// Keeps a query context and a pipeline executor for PollFlightInfo. - class PollSession - { - public: - PollSession( - std::unique_ptr session_, - ContextPtr query_context_, - ThreadGroupPtr thread_group_, - BlockIO && block_io_, - std::shared_ptr ch_to_arrow_converter_) - : session(std::move(session_)) - , query_context(query_context_) - , thread_group(thread_group_) - , block_io(std::move(block_io_)) - , executor(block_io.pipeline) - , ch_to_arrow_converter(ch_to_arrow_converter_) - { - } - - ~PollSession() = default; - - ThreadGroupPtr getThreadGroup() const { return thread_group; } - std::shared_ptr getCHToArrowConverter() const { return ch_to_arrow_converter; } - bool getNextBlock(Block & block) { return executor.pull(block); } - void onFinish() { block_io.onFinish(); } - void onException() { block_io.onException(); } - - private: - std::unique_ptr session; - ContextPtr query_context; - ThreadGroupPtr thread_group; - BlockIO block_io; - PullingPipelineExecutor executor; - std::shared_ptr ch_to_arrow_converter; - }; - - /// Creates a converter to convert ClickHouse blocks to the Arrow format. - std::shared_ptr createCHToArrowConverter(const Block & header) - { - CHColumnToArrowColumn::Settings arrow_settings; - arrow_settings.output_string_as_string = true; - auto ch_to_arrow_converter = std::make_shared(header, "Arrow", arrow_settings); - ch_to_arrow_converter->initializeArrowSchema(); - return ch_to_arrow_converter; - } -} - - -/// Keeps information about calls - e.g. blocks extracted from query pipelines, flight tickets, poll descriptors. 
-class ArrowFlightHandler::CallsData -{ -public: - CallsData(std::optional tickets_lifetime_, std::optional poll_descriptors_lifetime_, LoggerPtr log_) - : tickets_lifetime(tickets_lifetime_) - , poll_descriptors_lifetime(poll_descriptors_lifetime_) - , log(log_) - { - } - - /// Creates a flight ticket which allows to download a specified block. - std::shared_ptr createTicket(ConstBlockPtr block, std::shared_ptr ch_to_arrow_converter) - { - String ticket = generateTicketName(); - LOG_DEBUG(log, "Creating ticket {}", ticket); - auto expiration_time = calculateTicketExpirationTime(now()); - auto info = std::make_shared(); - info->ticket = ticket; - info->expiration_time = expiration_time; - info->block = block; - info->ch_to_arrow_converter = ch_to_arrow_converter; - std::lock_guard lock{mutex}; - bool inserted = tickets.try_emplace(ticket, info).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) - chassert(inserted); /// Flight tickets are unique. - if (expiration_time) - { - inserted = tickets_by_expiration_time.emplace(*expiration_time, ticket).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) - chassert(inserted); /// Flight tickets are unique. - updateNextExpirationTime(); - } - return info; - } - - [[nodiscard]] arrow::Result> getTicketInfo(const String & ticket) const - { - std::lock_guard lock{mutex}; - auto it = tickets.find(ticket); - if (it == tickets.end()) - return arrow::Status::KeyError("Ticket ", quoteString(ticket), " not found"); - return it->second; - } - - /// Finds the expiration time for a specified ticket. - /// If the ticket is not found it means it was expired and removed from the map. - std::optional getTicketExpirationTime(const String & ticket) const - { - if (!tickets_lifetime) - return std::nullopt; - std::lock_guard lock{mutex}; - auto it = tickets.find(ticket); - if (it == tickets.end()) - return ALREADY_EXPIRED; - return it->second->expiration_time; - } - - /// Extends the expiration time of a ticket. 
- /// The function calculates a new expiration time of a ticket based on the current time. - [[nodiscard]] arrow::Status extendTicketExpirationTime(const String & ticket) - { - if (!tickets_lifetime) - return arrow::Status::OK(); - std::lock_guard lock{mutex}; - auto it = tickets.find(ticket); - if (it == tickets.end()) - return arrow::Status::KeyError("Ticket ", quoteString(ticket), " not found"); - auto info = it->second; - auto old_expiration_time = info->expiration_time; - auto new_expiration_time = calculateTicketExpirationTime(now()); - auto new_info = std::make_shared(*info); - new_info->expiration_time = new_expiration_time; - it->second = new_info; - tickets_by_expiration_time.erase(std::make_pair(*old_expiration_time, ticket)); - tickets_by_expiration_time.emplace(*new_expiration_time, ticket); - updateNextExpirationTime(); - return arrow::Status::OK(); - } - - /// Cancels a ticket to free memory. - /// Tickets are cancelled either by timer (if setting "arrowflight.tickets_lifetime_seconds" > 0) - /// or after they are used by method DoGet (if setting "arrowflight.cancel_flight_descriptor_after_poll_flight_info" is set to true). - void cancelTicket(const String & ticket) - { - std::lock_guard lock{mutex}; - auto it = tickets.find(ticket); - if (it == tickets.end()) - return; /// The ticked has been already cancelled. - LOG_DEBUG(log, "Cancelling ticket {}", ticket); - auto info = it->second; - tickets.erase(it); - if (info->expiration_time) - { - tickets_by_expiration_time.erase(std::make_pair(*info->expiration_time, ticket)); - updateNextExpirationTime(); - } - } - - /// Creates a poll descriptor. - /// Poll descriptors are returned by method PollFlightInfo to get subsequent results from a long-running query. 
- std::shared_ptr - createPollDescriptor(std::unique_ptr poll_session, std::shared_ptr previous_info) - { - String poll_descriptor; - if (previous_info) - { - if (!previous_info->evaluated) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Adding a poll descriptor while the previous poll descriptor is not evaluated"); - if (!previous_info->next_poll_descriptor) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Adding a poll descriptor while the previous poll descriptor is final"); - poll_descriptor = *previous_info->next_poll_descriptor; - } - else - { - poll_descriptor = generatePollDescriptorName(); - } - LOG_DEBUG(log, "Creating poll descriptor {}", poll_descriptor); - auto current_time = now(); - auto expiration_time = calculatePollDescriptorExpirationTime(current_time); - auto info = std::make_shared(); - info->poll_descriptor = poll_descriptor; - info->expiration_time = expiration_time; - info->ch_to_arrow_converter = poll_session->getCHToArrowConverter(); - std::lock_guard lock{mutex}; - bool inserted = poll_descriptors.try_emplace(poll_descriptor, info).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) - chassert(inserted); /// Poll descriptors are unique. - inserted = poll_sessions.try_emplace(poll_descriptor, std::move(poll_session)).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) - chassert(inserted); /// Poll descriptors are unique. - if (expiration_time) - { - inserted = poll_descriptors_by_expiration_time.emplace(*expiration_time, poll_descriptor).second; /// NOLINT(clang-analyzer-deadcode.DeadStores) - chassert(inserted); /// Poll descriptors are unique. 
- updateNextExpirationTime(); - } - return info; - } - - [[nodiscard]] arrow::Result> getPollDescriptorInfo(const String & poll_descriptor) const - { - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it == poll_descriptors.end()) - return arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); - return it->second; - } - - /// Finds the expiration time for a specified poll descriptor. - /// If the poll descriptor is not found it means it was expired and removed from the map. - PollDescriptorWithExpirationTime getPollDescriptorWithExpirationTime(const String & poll_descriptor) const - { - if (!poll_descriptors_lifetime) - return PollDescriptorWithExpirationTime{.poll_descriptor = poll_descriptor, .expiration_time = std::nullopt}; - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it == poll_descriptors.end()) - return PollDescriptorWithExpirationTime{.poll_descriptor = poll_descriptor, .expiration_time = ALREADY_EXPIRED}; - return *it->second; - } - - /// Extends the expiration time of a poll descriptor. - /// The function calculates a new expiration time of a ticket based on the current time. 
- [[nodiscard]] arrow::Status extendPollDescriptorExpirationTime(const String & poll_descriptor) - { - if (!poll_descriptors_lifetime) - return arrow::Status::OK(); - auto current_time = now(); - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it == poll_descriptors.end()) - return arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); - auto info = it->second; - auto old_expiration_time = info->expiration_time; - auto new_expiration_time = calculatePollDescriptorExpirationTime(current_time); - auto new_info = std::make_shared(*info); - new_info->expiration_time = new_expiration_time; - it->second = new_info; - poll_descriptors_by_expiration_time.erase(std::make_pair(*old_expiration_time, poll_descriptor)); - poll_descriptors_by_expiration_time.emplace(*new_expiration_time, poll_descriptor); - updateNextExpirationTime(); - return arrow::Status::OK(); - } - - /// Starts evaluation (i.e. getting a block of data) for a specified poll descriptor. - /// The function returns nullptr if it's already evaluated. - /// If it's being evaluated at the moment in another thread the function waits until it finishes and then returns nullptr. 
- [[nodiscard]] arrow::Result> startEvaluation(const String & poll_descriptor) - { - arrow::Result> res; - std::unique_lock lock{mutex}; - evaluation_ended.wait(lock, [&]() TSA_REQUIRES(mutex) - { - auto it = poll_descriptors.find(poll_descriptor); - if (it == poll_descriptors.end()) - { - res = arrow::Status::KeyError("Poll descriptor ", quoteString(poll_descriptor), " not found"); - return true; - } - auto info = it->second; - if (info->evaluated) - { - res = std::unique_ptr{nullptr}; - return true; - } - if (!info->evaluating) - { - auto it2 = poll_sessions.find(poll_descriptor); - if (it2 == poll_sessions.end()) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Session is not attached to non-evaluated poll descriptor {}", poll_descriptor); - res = std::move(it2->second); - poll_sessions.erase(it2); - auto new_info = std::make_shared(*info); - new_info->evaluating = true; - it->second = new_info; - return true; - } - return false; /// The poll descriptor is being evaluating in another thread, we need to wait. - }); - return res; - } - - /// Ends evaluation for a specified poll descriptor. - void endEvaluation(const String & poll_descriptor, const String & ticket, UInt64 rows, UInt64 bytes, bool last) - { - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it == poll_descriptors.end()) - { - /// The poll descriptor expired during the query execution. 
- return; - } - - auto info = it->second; - if (info->evaluated) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Poll descriptor can't be evaluated twice"); - - auto new_info = std::make_shared(*info); - new_info->evaluating = false; - new_info->evaluated = true; - new_info->status = arrow::Status::OK(); - new_info->ticket = ticket; - new_info->rows = rows; - new_info->bytes = bytes; - if (!last) - new_info->next_poll_descriptor = generatePollDescriptorName(); - it->second = new_info; - info = new_info; - evaluation_ended.notify_all(); - } - - /// Ends evaluation for a specified poll descriptor with an error. - void endEvaluationWithError(const String & poll_descriptor, const arrow::Status & error_status) - { - chassert(!error_status.ok()); - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it != poll_descriptors.end()) - { - auto info = it->second; - if (!info->evaluated) - { - auto new_info = std::make_shared(*info); - new_info->evaluating = false; - new_info->evaluated = true; - new_info->status = error_status; - it->second = new_info; - info = new_info; - evaluation_ended.notify_all(); - } - } - } - - /// Cancels a poll descriptor to free memory. - /// Poll descriptors are cancelled either by timer (if setting "arrowflight.poll_descriptors_lifetime_seconds" > 0) - /// or after they are used by method PollFlightInfo (if setting "arrowflight.cancel_ticket_after_do_get" is set to true). 
- void cancelPollDescriptor(const String & poll_descriptor) - { - std::lock_guard lock{mutex}; - auto it = poll_descriptors.find(poll_descriptor); - if (it != poll_descriptors.end()) - { - LOG_DEBUG(log, "Cancelling poll descriptor {}", poll_descriptor); - auto info = it->second; - poll_descriptors.erase(it); - if (info->expiration_time) - { - poll_descriptors_by_expiration_time.erase(std::make_pair(*info->expiration_time, poll_descriptor)); - updateNextExpirationTime(); - } - } - auto it2 = poll_sessions.find(poll_descriptor); - if (it2 != poll_sessions.end()) - poll_sessions.erase(it2); - } - - /// Cancels tickets and poll descriptors if the current time is greater than their expiration time. - void cancelExpired() - { - auto current_time = now(); - std::lock_guard lock{mutex}; - while (!tickets_by_expiration_time.empty()) - { - auto it = tickets_by_expiration_time.begin(); - if (current_time <= it->first) - break; - LOG_DEBUG(log, "Cancelling expired ticket {}", it->second); - tickets.erase(it->second); - tickets_by_expiration_time.erase(it); - } - while (!poll_descriptors_by_expiration_time.empty()) - { - auto it = poll_descriptors_by_expiration_time.begin(); - if (current_time <= it->first) - break; - LOG_DEBUG(log, "Cancelling expired poll descriptor {}", it->second); - poll_descriptors.erase(it->second); - poll_sessions.erase(it->second); - poll_descriptors_by_expiration_time.erase(it); - } - updateNextExpirationTime(); - } - - /// Waits until maybe it's time to cancel expired tickets or poll descriptors. - void waitNextExpirationTime() const - { - auto current_time = now(); - std::unique_lock lock{mutex}; - auto expiration_time = next_expiration_time; - auto is_ready = [&] - { - if (stop_waiting_next_expiration_time) - return true; - if (next_expiration_time != expiration_time) - return true; /// We need to restart waiting if the next expiration time has changed. 
- current_time = now(); - return (expiration_time && (current_time > *expiration_time)); - }; - if (expiration_time) - { - if (current_time < *expiration_time) - next_expiration_time_updated.wait_for(lock, *expiration_time - current_time, is_ready); - } - else - { - next_expiration_time_updated.wait(lock, is_ready); - } - } - - void stopWaitingNextExpirationTime() - { - std::lock_guard lock{mutex}; - stop_waiting_next_expiration_time = true; - next_expiration_time_updated.notify_all(); - } - -private: - static String generateTicketName() - { - return TICKET_PREFIX + toString(UUIDHelpers::generateV4()); - } - - static String generatePollDescriptorName() - { - return POLL_DESCRIPTOR_PREFIX + toString(UUIDHelpers::generateV4()); - } - - std::optional calculateTicketExpirationTime(Timestamp current_time) const - { - if (!tickets_lifetime) - return std::nullopt; - return current_time + *tickets_lifetime; - } - - std::optional calculatePollDescriptorExpirationTime(Timestamp current_time) const - { - if (!poll_descriptors_lifetime) - return std::nullopt; - return current_time + *poll_descriptors_lifetime; - } - - void updateNextExpirationTime() TSA_REQUIRES(mutex) - { - auto expiration_time = next_expiration_time; - next_expiration_time.reset(); - if (!tickets_by_expiration_time.empty()) - next_expiration_time = tickets_by_expiration_time.begin()->first; - if (!poll_descriptors_by_expiration_time.empty()) - { - auto other_expiration_time = poll_descriptors_by_expiration_time.begin()->first; - next_expiration_time = next_expiration_time ? 
std::min(*next_expiration_time, other_expiration_time) : other_expiration_time; - } - if (next_expiration_time != expiration_time) - next_expiration_time_updated.notify_all(); - } - - const std::optional tickets_lifetime; - const std::optional poll_descriptors_lifetime; - const LoggerPtr log; - mutable std::mutex mutex; - std::unordered_map> tickets TSA_GUARDED_BY(mutex); - std::unordered_map> poll_descriptors TSA_GUARDED_BY(mutex); - std::unordered_map> poll_sessions TSA_GUARDED_BY(mutex); - std::condition_variable evaluation_ended; - /// `tickets_by_expiration_time` and `poll_descriptors_by_expiration_time` are sorted by `expiration_time` so `std::set` is used. - std::set> tickets_by_expiration_time TSA_GUARDED_BY(mutex); - std::set> poll_descriptors_by_expiration_time TSA_GUARDED_BY(mutex); - std::optional next_expiration_time; - mutable std::condition_variable next_expiration_time_updated; - bool stop_waiting_next_expiration_time = false; -}; - - -ArrowFlightHandler::ArrowFlightHandler(IServer & server_, const Poco::Net::SocketAddress & address_to_listen_) - : server(server_) - , log(getLogger("ArrowFlightHandler")) - , address_to_listen(address_to_listen_) - , tickets_lifetime_seconds(server.config().getUInt("arrowflight.tickets_lifetime_seconds", 600)) - , cancel_ticket_after_do_get(server.config().getBool("arrowflight.cancel_ticket_after_do_get", false)) - , poll_descriptors_lifetime_seconds(server.config().getUInt("arrowflight.poll_descriptors_lifetime_seconds", 600)) - , cancel_poll_descriptor_after_poll_flight_info(server.config().getBool("arrowflight.cancel_flight_descriptor_after_poll_flight_info", false)) - , calls_data( - std::make_unique( - tickets_lifetime_seconds ? std::make_optional(std::chrono::seconds{tickets_lifetime_seconds}) : std::optional{}, - poll_descriptors_lifetime_seconds ? 
std::make_optional(std::chrono::seconds{poll_descriptors_lifetime_seconds}) - : std::optional{}, - log)) -{ -} - -void ArrowFlightHandler::start() -{ - chassert(!initialized && !stopped); - - bool use_tls = server.config().getBool("arrowflight.enable_ssl", false); - - auto location = addressToArrowLocation(address_to_listen, use_tls); - - arrow::flight::FlightServerOptions options(location); - options.auth_handler = std::make_unique(); - options.middleware.emplace_back(AUTHORIZATION_MIDDLEWARE_NAME, std::make_shared()); - - if (use_tls) - { - auto cert_path = server.config().getString("arrowflight.ssl_cert_file"); - auto key_path = server.config().getString("arrowflight.ssl_key_file"); - - auto cert = readFile(cert_path); - auto key = readFile(key_path); - - options.tls_certificates.push_back(arrow::flight::CertKeyPair{cert, key}); - } - - auto init_status = Init(options); - if (!init_status.ok()) - { - throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Failed init Arrow Flight Server: {}", init_status.ToString()); - } - - initialized = true; - - server_thread.emplace([this] - { - try - { - DB::setThreadName(ThreadName::ARROW_FLIGHT_SERVER); - if (stopped) - return; - auto serve_status = Serve(); - if (!serve_status.ok()) - LOG_ERROR(log, "Failed to serve Arrow Flight: {}", serve_status.ToString()); - } - catch (...) - { - tryLogCurrentException(log, "Failed to serve Arrow Flight"); - } - }); - - if (tickets_lifetime_seconds || poll_descriptors_lifetime_seconds) - { - cleanup_thread.emplace([this] - { - try - { - DB::setThreadName(ThreadName::ARROW_FLIGHT_EXPR); - while (!stopped) - { - calls_data->waitNextExpirationTime(); - calls_data->cancelExpired(); - } - } - catch (...) 
- { - tryLogCurrentException(log, "Failed to cleanup"); - } - }); - } -} - -ArrowFlightHandler::~ArrowFlightHandler() = default; - -void ArrowFlightHandler::stop() -{ - if (!initialized) - return; - - if (!stopped.exchange(true)) - { - try - { - auto status = Shutdown(); - if (!status.ok()) - LOG_ERROR(log, "Failed to shutdown Arrow Flight: {}", status.ToString()); - } - catch (...) - { - tryLogCurrentException(log, "Failed to shutdown Arrow Flight"); - } - if (server_thread) - { - server_thread->join(); - server_thread.reset(); - } - - calls_data->stopWaitingNextExpirationTime(); - if (cleanup_thread) - { - cleanup_thread->join(); - cleanup_thread.reset(); - } - } -} - -UInt16 ArrowFlightHandler::portNumber() const -{ - return address_to_listen.port(); -} - -arrow::Status ArrowFlightHandler::GetFlightInfo( - const arrow::flight::ServerCallContext & context, - const arrow::flight::FlightDescriptor & request, - std::unique_ptr * info) -{ - auto impl = [&] - { - LOG_INFO(log, "GetFlightInfo is called for descriptor {}", request.ToString()); - - std::vector endpoints; - int64_t total_rows = 0; - int64_t total_bytes = 0; - std::shared_ptr ch_to_arrow_converter; - - if ((request.type == arrow::flight::FlightDescriptor::CMD) && hasPollDescriptorPrefix(request.cmd)) - { - return arrow::Status::Invalid("Method GetFlightInfo cannot be called with a flight descriptor returned by method PollFlightInfo"); - } - else - { - auto sql_res = convertGetDescriptorToSQL(request); - ARROW_RETURN_NOT_OK(sql_res); - const String & sql = sql_res.ValueOrDie(); - - Session session{server.context(), ClientInfo::Interface::ARROW_FLIGHT}; - - const auto & auth = AuthMiddleware::get(context); - session.authenticate(auth.username(), auth.password(), getClientAddress(context)); - - auto query_context = session.makeQueryContext(); - query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
- QueryScope query_scope = QueryScope::create(query_context); - - auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); - try - { - ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); - ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); - - PullingPipelineExecutor executor{block_io.pipeline}; - ch_to_arrow_converter = createCHToArrowConverter(executor.getHeader()); - - Block block; - while (executor.pull(block)) - { - if (!block.empty()) - { - total_rows += block.rows(); - total_bytes += block.bytes(); - auto ticket_info = calls_data->createTicket(std::make_shared(std::move(block)), ch_to_arrow_converter); - arrow::flight::FlightEndpoint endpoint; - endpoint.ticket = arrow::flight::Ticket{ticket_info->ticket}; - endpoint.expiration_time = ticket_info->expiration_time; - endpoints.emplace_back(endpoint); - } - } - block_io.onFinish(); - } - catch (...) - { - block_io.onException(); - throw; - } - } - - auto schema = ch_to_arrow_converter->getArrowSchema(); - - auto flight_info_res = arrow::flight::FlightInfo::Make( - *schema, - request, - endpoints, - total_rows, - total_bytes, - /* ordered = */ true); - - ARROW_RETURN_NOT_OK(flight_info_res); - *info = std::make_unique(std::move(flight_info_res).ValueOrDie()); - - LOG_INFO(log, "GetFlightInfo returns flight info {}", (*info)->ToString()); - return arrow::Status::OK(); - }; - return tryRunAndLogIfError("GetFlightInfo", impl); -} - - -arrow::Status ArrowFlightHandler::GetSchema( - const arrow::flight::ServerCallContext & context, - const arrow::flight::FlightDescriptor & request, - std::unique_ptr * schema) -{ - auto impl = [&] - { - LOG_INFO(log, "GetSchema is called for descriptor {}", request.ToString()); - std::shared_ptr ch_to_arrow_converter; - - if ((request.type == arrow::flight::FlightDescriptor::CMD) && hasPollDescriptorPrefix(request.cmd)) - { - const String & poll_descriptor = request.cmd; - 
ARROW_RETURN_NOT_OK(calls_data->extendPollDescriptorExpirationTime(poll_descriptor)); - auto poll_info_res = calls_data->getPollDescriptorInfo(poll_descriptor); - ARROW_RETURN_NOT_OK(poll_info_res); - const auto & poll_info = poll_info_res.ValueOrDie(); - ch_to_arrow_converter = poll_info->ch_to_arrow_converter; - } - else - { - auto sql_res = convertGetDescriptorToSQL(request); - ARROW_RETURN_NOT_OK(sql_res); - const String & sql = sql_res.ValueOrDie(); - - Session session{server.context(), ClientInfo::Interface::ARROW_FLIGHT}; - - const auto & auth = AuthMiddleware::get(context); - session.authenticate(auth.username(), auth.password(), getClientAddress(context)); - - auto query_context = session.makeQueryContext(); - query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. - QueryScope query_scope = QueryScope::create(query_context); - - auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); - ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); - ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); - - ch_to_arrow_converter = createCHToArrowConverter(block_io.pipeline.getHeader()); - } - - auto schema_res = arrow::flight::SchemaResult::Make(*ch_to_arrow_converter->getArrowSchema()); - ARROW_RETURN_NOT_OK(schema_res); - *schema = std::make_unique(*std::move(schema_res).ValueOrDie()); - - LOG_INFO(log, "GetSchema returns schema {}", ch_to_arrow_converter->getArrowSchema()->ToString()); - return arrow::Status::OK(); - }; - return tryRunAndLogIfError("GetSchema", impl); -} - - -arrow::Status ArrowFlightHandler::PollFlightInfo( - const arrow::flight::ServerCallContext & context, - const arrow::flight::FlightDescriptor & request, - std::unique_ptr * info) -{ - auto impl = [&] - { - LOG_INFO(log, "PollFlightInfo is called for descriptor {}", request.ToString()); - - std::shared_ptr poll_info; - std::shared_ptr ch_to_arrow_converter; - std::optional next_poll_descriptor; - bool 
should_cancel_poll_descriptor = false; - - if ((request.type == arrow::flight::FlightDescriptor::CMD) && hasPollDescriptorPrefix(request.cmd)) - { - const String & poll_descriptor = request.cmd; - ARROW_RETURN_NOT_OK(evaluatePollDescriptor(poll_descriptor)); - ARROW_RETURN_NOT_OK(calls_data->extendPollDescriptorExpirationTime(poll_descriptor)); - auto poll_info_res = calls_data->getPollDescriptorInfo(poll_descriptor); - ARROW_RETURN_NOT_OK(poll_info_res); - poll_info = poll_info_res.ValueOrDie(); - ch_to_arrow_converter = poll_info->ch_to_arrow_converter; - if (poll_info->next_poll_descriptor) - next_poll_descriptor = calls_data->getPollDescriptorWithExpirationTime(*poll_info->next_poll_descriptor); - should_cancel_poll_descriptor = cancel_poll_descriptor_after_poll_flight_info; - } - else - { - auto sql_res = convertGetDescriptorToSQL(request); - ARROW_RETURN_NOT_OK(sql_res); - const String & sql = sql_res.ValueOrDie(); - - auto session = std::make_unique(server.context(), ClientInfo::Interface::ARROW_FLIGHT); - - const auto & auth = AuthMiddleware::get(context); - session->authenticate(auth.username(), auth.password(), getClientAddress(context)); - - auto query_context = session->makeQueryContext(); - query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
- - auto thread_group = ThreadGroup::createForQuery(query_context); - CurrentThread::attachToGroup(thread_group); - - auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); - try - { - ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); - ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); - - ch_to_arrow_converter = createCHToArrowConverter(block_io.pipeline.getHeader()); - - auto poll_session = std::make_unique(std::move(session), query_context, thread_group, std::move(block_io), - ch_to_arrow_converter); - - auto next_info = calls_data->createPollDescriptor(std::move(poll_session), /* previous_info = */ nullptr); - next_poll_descriptor = *next_info; - } - catch (...) - { - block_io.onException(); - throw; - } - } - - std::vector endpoints; - int64_t total_rows = 0; - int64_t total_bytes = 0; - - while (poll_info) - { - if (poll_info->ticket) - { - arrow::flight::FlightEndpoint endpoint; - endpoint.ticket = arrow::flight::Ticket{*poll_info->ticket}; - endpoint.expiration_time = calls_data->getTicketExpirationTime(*poll_info->ticket); - endpoints.emplace_back(endpoint); - } - if (poll_info->rows) - total_rows += *poll_info->rows; - if (poll_info->bytes) - total_bytes += *poll_info->bytes; - poll_info = poll_info->previous_info; - } - std::reverse(endpoints.begin(), endpoints.end()); - - std::unique_ptr flight_info; - if (!endpoints.empty()) - { - auto flight_info_res = arrow::flight::FlightInfo::Make(*ch_to_arrow_converter->getArrowSchema(), request, endpoints, total_rows, total_bytes, /* ordered = */ true); - ARROW_RETURN_NOT_OK(flight_info_res); - flight_info = std::make_unique(flight_info_res.ValueOrDie()); - } - - std::optional next; - std::optional expiration_time; - if (next_poll_descriptor) - { - next = arrow::flight::FlightDescriptor::Command(next_poll_descriptor->poll_descriptor); - expiration_time = next_poll_descriptor->expiration_time; - } - - *info = std::make_unique(std::move(flight_info), 
std::move(next), std::nullopt, expiration_time); - - if (should_cancel_poll_descriptor) - calls_data->cancelPollDescriptor(request.cmd); - - LOG_INFO(log, "PollFlightInfo returns {}", (*info)->ToString()); - return arrow::Status::OK(); - }; - return tryRunAndLogIfError("PollFlightInfo", impl); -} - - -/// evaluatePollDescriptors() pulls a block from the query pipeline. -/// This function blocks until it either gets a nonempty block from the query pipeline or finds out that there will be no blocks anymore. -/// -/// NOTE: The current implementation doesn't allow to set a timeout to avoid blocking calls as it's suggested in the documentation -/// for PollFlightInfo (see https://arrow.apache.org/docs/format/Flight.html#downloading-data-by-running-a-heavy-query). -arrow::Status ArrowFlightHandler::evaluatePollDescriptor(const String & poll_descriptor) -{ - auto poll_session_res = calls_data->startEvaluation(poll_descriptor); - ARROW_RETURN_NOT_OK(poll_session_res); - auto poll_session = std::move(poll_session_res).ValueOrDie(); - - if (!poll_session) - { - /// Already evaluated. 
- auto info_res = calls_data->getPollDescriptorInfo(poll_descriptor); - ARROW_RETURN_NOT_OK(info_res); - const auto & info = info_res.ValueOrDie(); - if (!info->evaluated) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Session is not attached to non-evaluated poll descriptor {}", poll_descriptor); - return *info->status; - } - - ThreadGroupSwitcher thread_group_switcher{poll_session->getThreadGroup(), ThreadName::ARROW_FLIGHT}; - auto ch_to_arrow_converter = poll_session->getCHToArrowConverter(); - bool last = false; - - try - { - String ticket; - UInt64 rows = 0; - UInt64 bytes = 0; - Block block; - if (poll_session->getNextBlock(block)) - { - if (!block.empty()) - { - rows = block.rows(); - bytes = block.bytes(); - auto ticket_info = calls_data->createTicket(std::make_shared(std::move(block)), ch_to_arrow_converter); - ticket = ticket_info->ticket; - } - } - else - { - last = true; - } - - calls_data->endEvaluation(poll_descriptor, ticket, rows, bytes, last); - poll_session->onFinish(); - } - catch (...) 
- { - tryLogCurrentException(log, "Poll: Failed to get next block"); - auto error_status = arrow::Status::ExecutionError("Poll: Failed to get next block: ", getCurrentExceptionMessage(/* with_stacktrace = */ false)); - calls_data->endEvaluationWithError(poll_descriptor, error_status); - poll_session->onException(); - return error_status; - } - - auto info_res = calls_data->getPollDescriptorInfo(poll_descriptor); - ARROW_RETURN_NOT_OK(info_res); - const auto & info = info_res.ValueOrDie(); - if (!last) - calls_data->createPollDescriptor(std::move(poll_session), info); - - return arrow::Status::OK(); -} - - -arrow::Status ArrowFlightHandler::DoGet( - const arrow::flight::ServerCallContext & context, - const arrow::flight::Ticket & request, - std::unique_ptr * stream) -{ - auto impl = [&] - { - LOG_INFO(log, "DoGet is called for ticket {}", request.ticket); - - Block header; - std::vector chunks; - std::shared_ptr ch_to_arrow_converter; - bool should_cancel_ticket = false; - - if (hasTicketPrefix(request.ticket)) - { - auto ticket_info_res = calls_data->getTicketInfo(request.ticket); - ARROW_RETURN_NOT_OK(ticket_info_res); - const auto & ticket_info = ticket_info_res.ValueOrDie(); - chunks.emplace_back(Chunk(ticket_info->block->getColumns(), ticket_info->block->rows())); - header = ticket_info->block->cloneEmpty(); - ch_to_arrow_converter = ticket_info->ch_to_arrow_converter->clone(/* copy_arrow_schema = */ true); - should_cancel_ticket = cancel_ticket_after_do_get; - } - else - { - const String & sql = request.ticket; - - Session session{server.context(), ClientInfo::Interface::ARROW_FLIGHT}; - - const auto & auth = AuthMiddleware::get(context); - session.authenticate(auth.username(), auth.password(), getClientAddress(context)); - - auto query_context = session.makeQueryContext(); - query_context->setCurrentQueryId(""); /// Empty string means the query id will be autogenerated. 
- QueryScope query_scope = QueryScope::create(query_context); - - auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); - try - { - ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); - ARROW_RETURN_NOT_OK(checkPipelineIsPulling(block_io.pipeline)); - - PullingPipelineExecutor executor{block_io.pipeline}; - - Block block; - while (executor.pull(block)) - chunks.emplace_back(Chunk(block.getColumns(), block.rows())); - - header = executor.getHeader(); - ch_to_arrow_converter = createCHToArrowConverter(header); - block_io.onFinish(); - } - catch (...) - { - block_io.onException(); - throw; - } - } - - std::shared_ptr arrow_table; - ch_to_arrow_converter->chChunkToArrowTable(arrow_table, chunks, header.columns()); - - auto stream_res = arrow::RecordBatchReader::MakeFromIterator( - arrow::Iterator>{arrow::TableBatchReader{arrow_table}}, - ch_to_arrow_converter->getArrowSchema()); - ARROW_RETURN_NOT_OK(stream_res); - *stream = std::make_unique(stream_res.ValueOrDie()); - - if (should_cancel_ticket) - calls_data->cancelTicket(request.ticket); - - LOG_INFO(log, "DoGet succeeded"); - return arrow::Status::OK(); - }; - return tryRunAndLogIfError("DoGet", impl); -} - - -arrow::Status ArrowFlightHandler::DoPut( - const arrow::flight::ServerCallContext & context, - std::unique_ptr reader, - std::unique_ptr /*writer*/) -{ - auto impl = [&] - { - const auto & descriptor = reader->descriptor(); - LOG_INFO(log, "DoPut is called for descriptor {}", descriptor.ToString()); - - auto sql_res = convertPutDescriptorToSQL(descriptor); - ARROW_RETURN_NOT_OK(sql_res); - const String & sql = sql_res.ValueOrDie(); - - Session session{server.context(), ClientInfo::Interface::ARROW_FLIGHT}; - - const auto & auth = AuthMiddleware::get(context); - session.authenticate(auth.username(), auth.password(), getClientAddress(context)); - - auto query_context = session.makeQueryContext(); - query_context->setCurrentQueryId(""); /// Empty string means the query 
id will be autogenerated. - QueryScope query_scope = QueryScope::create(query_context); - - auto [ast, block_io] = executeQuery(sql, query_context, QueryFlags{}, QueryProcessingStage::Complete); - try - { - ARROW_RETURN_NOT_OK(checkNoCustomFormat(ast)); - auto & pipeline = block_io.pipeline; - - if (pipeline.pushing()) - { - Block header = pipeline.getHeader(); - auto input = std::make_shared(std::move(reader), header); - pipeline.complete(Pipe(std::move(input))); - } - else if (pipeline.pulling()) - { - Block header = pipeline.getHeader(); - auto output = std::make_shared(std::make_shared(header)); - pipeline.complete(std::move(output)); - } - - CompletedPipelineExecutor executor(pipeline); - executor.execute(); - LOG_INFO(log, "DoPut succeeded"); - block_io.onFinish(); - } - catch (...) - { - block_io.onException(); - throw; - } - - return arrow::Status::OK(); - }; - return tryRunAndLogIfError("DoPut", impl); -} - - -arrow::Status ArrowFlightHandler::tryRunAndLogIfError(std::string_view method_name, std::function && func) const -{ - DB::setThreadName(ThreadName::ARROW_FLIGHT); - ThreadStatus thread_status; - try - { - auto status = std::move(func)(); - if (!status.ok()) - LOG_ERROR(log, "{} failed: {}", method_name, status.ToString()); - return status; - } - catch (...) - { - tryLogCurrentException(log, fmt::format("{} failed", method_name)); - return arrow::Status::ExecutionError(method_name, " failed: ", getCurrentExceptionMessage(/* with_stacktrace = */ false)); - } -} - - -arrow::Status ArrowFlightHandler::DoAction( - const arrow::flight::ServerCallContext & /*context*/, - const arrow::flight::Action & /*action*/, - std::unique_ptr * /*result*/) -{ - return arrow::Status::NotImplemented("NYI"); -} - -} - -#endif From 7a1bbe1e37103e63f4cfadf15b382d8de26606f1 Mon Sep 17 00:00:00 2001 From: "Nihal Z. 
Miaji" <81457724+nihalzp@users.noreply.github.com> Date: Wed, 15 Apr 2026 17:09:10 +0000 Subject: [PATCH 3/4] Cherry-pick of https://github.com/ClickHouse/ClickHouse/pull/101272 with unresolved conflict markers (resolution in next commit) --- Original cherry-pick message follows: Merge pull request #101272 from nihalzp/support-arrow-orc-nullable-tuple Support `Nullable(Tuple)` for `Arrow`, `ArrowStream`, `ORC`, legacy `Parquet` formats # Conflicts: # src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp # tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.reference # tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.sql --- src/DataTypes/NestedUtils.cpp | 73 ++ src/DataTypes/NestedUtils.h | 7 + src/Formats/insertNullAsDefaultIfNeeded.cpp | 26 + .../Formats/Impl/ArrowColumnToCHColumn.cpp | 30 +- .../Formats/Impl/CHColumnToArrowColumn.cpp | 10 +- .../Impl/NativeORCBlockInputFormat.cpp | 24 +- .../Formats/Impl/ORCBlockOutputFormat.cpp | 22 +- ...llable_low_cardinality_as_dict_in_arrow.sh | 12 + ...lable_low_cardinality_as_dict_in_arrow.sql | 8 - ...s_nullable_empty_tuple_roundtrip.reference | 186 +++ ...formats_nullable_empty_tuple_roundtrip.sql | 115 ++ ...ide_nullable_arrow_orc_roundtrip.reference | 456 +++++++ ...le_inside_nullable_arrow_orc_roundtrip.sql | 399 +++++++ ...nside_nullable_parquet_roundtrip.reference | 251 ++++ ...uple_inside_nullable_parquet_roundtrip.sql | 275 +++++ tmp/source_pr_chcolumn.diff | 1048 +++++++++++++++++ 16 files changed, 2918 insertions(+), 24 deletions(-) create mode 100755 tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sh delete mode 100644 tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sql create mode 100644 tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.reference create mode 100644 tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.sql create mode 100644 
tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.reference create mode 100644 tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.sql create mode 100644 tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.reference create mode 100644 tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.sql create mode 100644 tmp/source_pr_chcolumn.diff diff --git a/src/DataTypes/NestedUtils.cpp b/src/DataTypes/NestedUtils.cpp index 56ebe66c2ecc..8e7ed78e88d9 100644 --- a/src/DataTypes/NestedUtils.cpp +++ b/src/DataTypes/NestedUtils.cpp @@ -7,11 +7,14 @@ #include #include +#include #include #include #include #include +#include +#include #include #include @@ -121,6 +124,76 @@ std::string extractTableName(const std::string & nested_name) } +ColumnWithTypeAndName unwrapNullableTuple(const ColumnWithTypeAndName & column) +{ + const auto * type_nullable = typeid_cast(column.type.get()); + if (!type_nullable) + return column; + + const auto * tuple_type = typeid_cast(type_nullable->getNestedType().get()); + if (!tuple_type) + return column; + + const auto & col_nullable = assert_cast(*column.column); + + const auto & null_map_data = col_nullable.getNullMapData(); + bool has_nulls = !memoryIsZero(null_map_data.data(), 0, null_map_data.size()); + + if (!has_nulls) + { + /// No actual nulls — just strip the Nullable wrapper. + return {col_nullable.getNestedColumnPtr(), type_nullable->getNestedType(), column.name}; + } + + /// Propagate the struct null map to each Tuple element. 
+ const auto & inner_tuple = assert_cast(col_nullable.getNestedColumn()); + const auto & null_map_ptr = col_nullable.getNullMapColumnPtr(); + Columns new_elements; + DataTypes new_types; + for (size_t i = 0; i < tuple_type->getElements().size(); ++i) + { + auto elem_col = inner_tuple.getColumnPtr(i); + auto elem_type = tuple_type->getElement(i); + if (elem_type->isNullable()) + { + /// Element already Nullable — merge null maps (struct null OR element null). + const auto & existing = assert_cast(*elem_col); + auto merged = ColumnUInt8::create(null_map_ptr->size()); + const auto & s = assert_cast(*null_map_ptr).getData(); + const auto & e = existing.getNullMapData(); + auto & m = merged->getData(); + for (size_t j = 0; j < s.size(); ++j) + m[j] = s[j] | e[j]; + new_elements.push_back(ColumnNullable::create(existing.getNestedColumnPtr(), std::move(merged))); + new_types.push_back(elem_type); + } + else if (elem_type->canBeInsideNullable()) + { + new_elements.push_back(ColumnNullable::create(elem_col, null_map_ptr)); + new_types.push_back(std::make_shared(elem_type)); + } + else + { + /// Array, Map, etc. — replace values at null positions with type defaults. + const auto & nm = col_nullable.getNullMapData(); + auto mutable_col = elem_col->cloneEmpty(); + for (size_t j = 0; j < elem_col->size(); ++j) + { + if (nm[j]) + mutable_col->insertDefault(); + else + mutable_col->insertFrom(*elem_col, j); + } + new_elements.push_back(std::move(mutable_col)); + new_types.push_back(elem_type); + } + } + + auto result_type = tuple_type->hasExplicitNames() ? 
std::make_shared(std::move(new_types), tuple_type->getElementNames()) + : std::make_shared(std::move(new_types)); + return {ColumnTuple::create(std::move(new_elements)), result_type, column.name}; +} + static Block flattenImpl(const Block & block, bool flatten_named_tuple) { Block res; diff --git a/src/DataTypes/NestedUtils.h b/src/DataTypes/NestedUtils.h index c358cb46edcf..8ee706276704 100644 --- a/src/DataTypes/NestedUtils.h +++ b/src/DataTypes/NestedUtils.h @@ -62,6 +62,13 @@ namespace Nested /// Convert old-style nested (single arrays with same prefix, `n.a`, `n.b`...) to subcolumns of data type Nested. NamesAndTypesList convertToSubcolumns(const NamesAndTypesList & names_and_types); + /// Unwrap Nullable(Tuple(...)) into Tuple(...) by propagating the struct-level null map + /// to each element. Scalar elements become Nullable(T), already-Nullable elements get merged + /// null maps, and non-nullable-compatible elements (Array, Map) get defaults at null positions. + /// When there are no actual nulls, simply strips the Nullable wrapper. + /// Used by format readers (Arrow, ORC) to convert Nullable struct elements for Nested flattening. + ColumnWithTypeAndName unwrapNullableTuple(const ColumnWithTypeAndName & column); + /// Check that sizes of arrays - elements of nested data structures - are equal. void validateArraySizes(const Block & block); diff --git a/src/Formats/insertNullAsDefaultIfNeeded.cpp b/src/Formats/insertNullAsDefaultIfNeeded.cpp index 11162303c264..d62719375d61 100644 --- a/src/Formats/insertNullAsDefaultIfNeeded.cpp +++ b/src/Formats/insertNullAsDefaultIfNeeded.cpp @@ -94,6 +94,27 @@ bool insertNullAsDefaultIfNeeded(ColumnWithTypeAndName & input_column, const Col return true; } + /// When both input and header are Nullable, unwrap and recurse into the nested types. + /// This can handle cases such as e.g. 
Nullable(Tuple(Nullable(Int32), String)) vs Nullable(Tuple(UInt32, String)) + if (input_column.type->isNullable() && header_column.type->isNullable()) + { + ColumnWithTypeAndName nested_input; + nested_input.column = assert_cast(input_column.column.get())->getNestedColumnPtr(); + nested_input.type = removeNullable(input_column.type); + + ColumnWithTypeAndName nested_header; + nested_header.column = assert_cast(header_column.column.get())->getNestedColumnPtr(); + nested_header.type = removeNullable(header_column.type); + + if (!insertNullAsDefaultIfNeeded(nested_input, nested_header, 0, nullptr)) + return false; + + input_column.column = ColumnNullable::create( + nested_input.column, assert_cast(input_column.column.get())->getNullMapColumnPtr()); + input_column.type = std::make_shared(std::move(nested_input.type)); + return true; + } + if (!isNullableOrLowCardinalityNullable(input_column.type) || isNullableOrLowCardinalityNullable(header_column.type)) return false; @@ -118,6 +139,11 @@ bool insertNullAsDefaultIfNeeded(ColumnWithTypeAndName & input_column, const Col input_column.type = std::make_shared(removeNullable(lc_type->getDictionaryType())); } + /// After stripping the outer Nullable, the inner type may also need processing. + /// For example, Nullable(Tuple(Nullable(Int), String)) -> Tuple(Nullable(Int), String) + /// still needs the Tuple elements compared against the header to strip inner Nullable. 
+ insertNullAsDefaultIfNeeded(input_column, header_column, column_i, block_missing_values); + return true; } diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp index 1ed34febcd15..7a41c4901969 100644 --- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp +++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp @@ -105,6 +105,18 @@ static ColumnWithTypeAndName readColumnWithNumericData(const std::shared_ptr buffer = chunk->data()->buffers[1]; const auto * raw_data = reinterpret_cast(buffer->data()) + chunk->offset(); column_data.insert_assume_reserved(raw_data, raw_data + chunk->length()); + + /// Values at null positions are not guaranteed to be initialized in the source buffer. + /// Zero them out because downstream code (type conversions, serialization) may read all values. + if (chunk->null_count() > 0) + { + size_t start = column_data.size() - chunk->length(); + for (int64_t i = 0; i < chunk->length(); ++i) + { + if (chunk->IsNull(i)) + column_data[start + i] = {}; + } + } } return {std::move(internal_column), std::move(internal_type), column_name}; } @@ -1160,22 +1172,27 @@ static ColumnWithTypeAndName readNonNullableColumnFromArrowColumn( return readOffsetsFromArrowListColumn(arrow_column); } }(); - auto array_column = ColumnArray::create(nested_column.column, offsets_column); - DataTypePtr array_type; - /// If type hint is Nested, we should return Nested type, - /// because we differentiate Nested and simple Array(Tuple) + ColumnPtr array_data_column = nested_column.column; + /// If type hint is Nested and the element is a named Tuple, return the Nested type + /// so that `Nested::flatten` can decompose it into separate arrays. + /// When the element is Nullable(Tuple(...)) (e.g. from Arrow's default nullable schema), + /// unwrap it and propagate the struct null map to each element via `unwrapNullableTuple`. const auto * tuple_type = type_hint && isNested(type_hint) ? 
typeid_cast(removeNullable(nested_column.type).get()) : nullptr; if (tuple_type) { - array_type = createNested(tuple_type->getElements(), tuple_type->getElementNames()); + auto unwrapped = Nested::unwrapNullableTuple({array_data_column, nested_column.type, column_name}); + array_data_column = unwrapped.column; + const auto & result_tuple = assert_cast(*unwrapped.type); + array_type = createNested(result_tuple.getElements(), result_tuple.getElementNames()); } else { array_type = std::make_shared(nested_column.type); } + auto array_column = ColumnArray::create(array_data_column, offsets_column); return {std::move(array_column), array_type, column_name}; } case arrow::Type::STRUCT: @@ -1408,7 +1425,10 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( arrow_column->type()->id() != arrow::Type::LARGE_LIST && arrow_column->type()->id() != arrow::Type::FIXED_SIZE_LIST && arrow_column->type()->id() != arrow::Type::MAP && +<<<<<<< HEAD arrow_column->type()->id() != arrow::Type::STRUCT && +======= +>>>>>>> fc17de3cb80 (Merge pull request #101272 from nihalzp/support-arrow-orc-nullable-tuple) arrow_column->type()->id() != arrow::Type::DICTIONARY) { DataTypePtr nested_type_hint; diff --git a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp index 35bf14b7f2bd..d810ed8ace07 100644 --- a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp +++ b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp @@ -485,7 +485,10 @@ namespace DB if (column_tuple->tupleSize() == 0) { for (size_t i = start; i != end; ++i) - checkStatus(builder.Append(), column->getName(), format_name); + { + auto status = (null_bytemap && (*null_bytemap)[i]) ? 
builder.AppendNull() : builder.Append(); + checkStatus(status, column->getName(), format_name); + } return checkResult(builder.Finish(), column_name, format_name); } @@ -494,10 +497,13 @@ namespace DB for (size_t i = 0; i != column_tuple->tupleSize(); ++i) { ColumnPtr nested_column = column_tuple->getColumnPtr(i); + /// Do not propagate the struct-level null_bytemap to child fields. + /// In Arrow, struct-level nulls and child-level nulls are independent; + /// child values at null struct positions are undefined. auto name = column_name + "." + nested_names[i]; std::shared_ptr nested_arrow_array = fillArrowArray( name, - nested_column, nested_types[i], null_bytemap, + nested_column, nested_types[i], nullptr, builder.field_builder(static_cast(i)), format_name, start, end, diff --git a/src/Processors/Formats/Impl/NativeORCBlockInputFormat.cpp b/src/Processors/Formats/Impl/NativeORCBlockInputFormat.cpp index bfb8a02683d7..0ea9046a588d 100644 --- a/src/Processors/Formats/Impl/NativeORCBlockInputFormat.cpp +++ b/src/Processors/Formats/Impl/NativeORCBlockInputFormat.cpp @@ -1710,8 +1710,8 @@ ColumnWithTypeAndName ORCColumnToCHColumn::readColumnFromORCColumn( { bool skipped = false; - if (!inside_nullable && (orc_column->hasNulls || (type_hint && type_hint->isNullable())) && !orc_column->isEncoded - && (orc_type->getKind() != orc::LIST && orc_type->getKind() != orc::MAP && orc_type->getKind() != orc::STRUCT)) + if (!inside_nullable && (orc_column->hasNulls || (type_hint && isNullableOrLowCardinalityNullable(type_hint))) && !orc_column->isEncoded + && (orc_type->getKind() != orc::LIST && orc_type->getKind() != orc::MAP)) { DataTypePtr nested_type_hint; if (type_hint) @@ -1883,19 +1883,27 @@ ColumnWithTypeAndName ORCColumnToCHColumn::readColumnFromORCColumn( auto nested_column = readColumnFromORCColumn(orc_nested_column, orc_nested_type, column_name, false, nested_type_hint); auto offsets_column = readOffsetsFromORCListColumn(orc_list_column); - auto array_column = 
ColumnArray::create(nested_column.column, offsets_column); DataTypePtr array_type; - /// If type hint is Nested, we should return Nested type, - /// because we differentiate Nested and simple Array(Tuple) - if (type_hint && isNested(type_hint)) + ColumnPtr array_data_column = nested_column.column; + /// If type hint is Nested and the element is a named Tuple, return the Nested type + /// so that `Nested::flatten` can decompose it into separate arrays. + /// When the element is Nullable(Tuple(...)), unwrap it and propagate the struct null + /// map to each element via `unwrapNullableTuple`. + const auto * tuple_type = type_hint && isNested(type_hint) + ? typeid_cast(removeNullable(nested_column.type).get()) + : nullptr; + if (tuple_type) { - const auto & tuple_type = assert_cast(*nested_column.type); - array_type = createNested(tuple_type.getElements(), tuple_type.getElementNames()); + auto unwrapped = Nested::unwrapNullableTuple({array_data_column, nested_column.type, column_name}); + array_data_column = unwrapped.column; + const auto & result_tuple = assert_cast(*unwrapped.type); + array_type = createNested(result_tuple.getElements(), result_tuple.getElementNames()); } else { array_type = std::make_shared(nested_column.type); } + auto array_column = ColumnArray::create(array_data_column, offsets_column); return {array_column, array_type, column_name}; } case orc::STRUCT: diff --git a/src/Processors/Formats/Impl/ORCBlockOutputFormat.cpp b/src/Processors/Formats/Impl/ORCBlockOutputFormat.cpp index 13d6cb6656ad..e5ef21abf1b6 100644 --- a/src/Processors/Formats/Impl/ORCBlockOutputFormat.cpp +++ b/src/Processors/Formats/Impl/ORCBlockOutputFormat.cpp @@ -494,7 +494,27 @@ void ORCBlockOutputFormat::writeColumn( const auto & tuple_column = assert_cast(column); auto nested_types = assert_cast(type.get())->getElements(); for (size_t i = 0; i != tuple_column.tupleSize(); ++i) - writeColumn(*struct_orc_column.fields[i], tuple_column.getColumn(i), nested_types[i], nullptr); + 
{ + if (null_bytemap && nested_types[i]->isNullable()) + { + /// When both the struct and the element are nullable, we need to merge the two null bitmaps: + /// a child value is null if either the struct row is null OR the element itself is null. + const auto & nullable_col = assert_cast(tuple_column.getColumn(i)); + const auto & element_null_map = nullable_col.getNullMapData(); + PaddedPODArray merged_null_map(element_null_map.size()); + for (size_t j = 0; j < element_null_map.size(); ++j) + merged_null_map[j] = element_null_map[j] | (*null_bytemap)[j]; + + auto nested_type = removeNullable(nested_types[i]); + writeColumn(*struct_orc_column.fields[i], nullable_col.getNestedColumn(), nested_type, &merged_null_map); + } + else + { + /// Propagate the struct-level null_bytemap to children so the ORC library correctly handles + /// null struct rows (child values at null positions must also be marked null). + writeColumn(*struct_orc_column.fields[i], tuple_column.getColumn(i), nested_types[i], null_bytemap); + } + } break; } case TypeIndex::Map: diff --git a/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sh b/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sh new file mode 100755 index 000000000000..c54c1831a1b4 --- /dev/null +++ b/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# Tags: no-fasttest +# no-fasttest: Arrow format is not available in fasttest builds + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. 
"$CURDIR"/../shell_config.sh + +$CLICKHOUSE_LOCAL -q "select toLowCardinality(toNullable('abc')) as lc format Arrow settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=0" | $CLICKHOUSE_LOCAL --input-format=Arrow --table=test -q "desc test" +$CLICKHOUSE_LOCAL -q "select toLowCardinality(toNullable('abc')) as lc format Arrow settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=0" | $CLICKHOUSE_LOCAL --input-format=Arrow --table=test -q "select * from test" +$CLICKHOUSE_LOCAL -q "select toLowCardinality(toNullable('abc')) as lc format Arrow settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=1" | $CLICKHOUSE_LOCAL --input-format=Arrow --table=test -q "desc test" +$CLICKHOUSE_LOCAL -q "select toLowCardinality(toNullable('abc')) as lc format Arrow settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=1" | $CLICKHOUSE_LOCAL --input-format=Arrow --table=test -q "select * from test" diff --git a/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sql b/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sql deleted file mode 100644 index 975e7fb88267..000000000000 --- a/tests/queries/0_stateless/02384_nullable_low_cardinality_as_dict_in_arrow.sql +++ /dev/null @@ -1,8 +0,0 @@ --- Tags: no-fasttest - -insert into function file(02384_data.arrow) select toLowCardinality(toNullable('abc')) as lc settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=0, engine_file_truncate_on_insert=1; -desc file(02384_data.arrow); -select * from file(02384_data.arrow); -insert into function file(02384_data.arrow) select toLowCardinality(toNullable('abc')) as lc settings output_format_arrow_low_cardinality_as_dictionary=1, output_format_arrow_string_as_string=1, engine_file_truncate_on_insert=1; -desc file(02384_data.arrow); 
-select * from file(02384_data.arrow); diff --git a/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.reference b/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.reference new file mode 100644 index 000000000000..2f627fd884b2 --- /dev/null +++ b/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.reference @@ -0,0 +1,186 @@ +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; +DROP TABLE IF EXISTS test_nullable_empty_tuple; +CREATE TABLE test_nullable_empty_tuple (c0 Nullable(Tuple())) ENGINE = Memory; +INSERT INTO TABLE test_nullable_empty_tuple (c0) VALUES (()), (NULL), (()); +SELECT 'CSV'; +CSV +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.csv', 'CSV', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.csv', 'CSV', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'TabSeparated'; +TabSeparated +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tsv', 'TabSeparated', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.tsv', 'TabSeparated', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'TabSeparatedRaw'; +TabSeparatedRaw +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tsvraw', 'TabSeparatedRaw', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.tsvraw', 'TabSeparatedRaw', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONEachRow'; +JSONEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.json', 'JSONEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.json', 'JSONEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONCompactEachRow'; +JSONCompactEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncompact', 
'JSONCompactEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncompact', 'JSONCompactEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONStringsEachRow'; +JSONStringsEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonstr', 'JSONStringsEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonstr', 'JSONStringsEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONCompactStringsEachRow'; +JSONCompactStringsEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncstr', 'JSONCompactStringsEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncstr', 'JSONCompactStringsEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'Values'; +Values +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.values', 'Values', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.values', 'Values', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'TSKV'; +TSKV +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tskv', 'TSKV', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.tskv', 'TSKV', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'CustomSeparated'; +CustomSeparated +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.custom', 'CustomSeparated', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.custom', 'CustomSeparated', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'Native'; +Native +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.native', 'Native', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.native', 
'Native', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'RowBinary'; +RowBinary +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.rowbin', 'RowBinary', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.rowbin', 'RowBinary', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'Avro'; +Avro +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.avro', 'Avro', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.avro', 'Avro', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'MsgPack'; +MsgPack +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.msgpack', 'MsgPack', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.msgpack', 'MsgPack', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'BSONEachRow'; +BSONEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.bson', 'BSONEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.bson', 'BSONEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSON'; +JSON +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonall', 'JSON', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonall', 'JSON', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONCompact'; +JSONCompact +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncompactall', 'JSONCompact', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncompactall', 'JSONCompact', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONColumns'; +JSONColumns +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncols', 'JSONColumns', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM 
file(currentDatabase() || '_04019.jsoncols', 'JSONColumns', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONCompactColumns'; +JSONCompactColumns +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonccols', 'JSONCompactColumns', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonccols', 'JSONCompactColumns', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONColumnsWithMetadata'; +JSONColumnsWithMetadata +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncolsmeta', 'JSONColumnsWithMetadata', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncolsmeta', 'JSONColumnsWithMetadata', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'JSONObjectEachRow'; +JSONObjectEachRow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonobj', 'JSONObjectEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonobj', 'JSONObjectEachRow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'Buffers'; +Buffers +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.buf', 'Buffers', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.buf', 'Buffers', 'c0 Nullable(Tuple())'); +() +\N +() +-- Parquet doesn't support empty tuples by design +SELECT 'Parquet'; +Parquet +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.parquet', 'Parquet', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; -- { serverError BAD_ARGUMENTS } +SELECT 'Arrow'; +Arrow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.arrow', 'Arrow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.arrow', 'Arrow', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'ArrowStream'; +ArrowStream +INSERT INTO TABLE FUNCTION 
file(currentDatabase() || '_04019.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())'); +() +\N +() +SELECT 'ORC'; +ORC +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.orc', 'ORC', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.orc', 'ORC', 'c0 Nullable(Tuple())'); +() +\N +() diff --git a/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.sql b/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.sql new file mode 100644 index 000000000000..ee194259f764 --- /dev/null +++ b/tests/queries/0_stateless/04019_formats_nullable_empty_tuple_roundtrip.sql @@ -0,0 +1,115 @@ +-- Tags: no-fasttest +-- no-fasttest: Some formats not available in fasttest environment + +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; + +DROP TABLE IF EXISTS test_nullable_empty_tuple; +CREATE TABLE test_nullable_empty_tuple (c0 Nullable(Tuple())) ENGINE = Memory; +INSERT INTO TABLE test_nullable_empty_tuple (c0) VALUES (()), (NULL), (()); + +SELECT 'CSV'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.csv', 'CSV', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.csv', 'CSV', 'c0 Nullable(Tuple())'); + +SELECT 'TabSeparated'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tsv', 'TabSeparated', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.tsv', 'TabSeparated', 'c0 Nullable(Tuple())'); + +SELECT 'TabSeparatedRaw'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tsvraw', 'TabSeparatedRaw', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() ||
'_04019.tsvraw', 'TabSeparatedRaw', 'c0 Nullable(Tuple())'); + +SELECT 'JSONEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.json', 'JSONEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.json', 'JSONEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'JSONCompactEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncompact', 'JSONCompactEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncompact', 'JSONCompactEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'JSONStringsEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonstr', 'JSONStringsEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonstr', 'JSONStringsEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'JSONCompactStringsEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncstr', 'JSONCompactStringsEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncstr', 'JSONCompactStringsEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'Values'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.values', 'Values', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.values', 'Values', 'c0 Nullable(Tuple())'); + +SELECT 'TSKV'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.tskv', 'TSKV', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.tskv', 'TSKV', 'c0 Nullable(Tuple())'); + +SELECT 'CustomSeparated'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.custom', 'CustomSeparated', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || 
'_04019.custom', 'CustomSeparated', 'c0 Nullable(Tuple())'); + +SELECT 'Native'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.native', 'Native', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.native', 'Native', 'c0 Nullable(Tuple())'); + +SELECT 'RowBinary'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.rowbin', 'RowBinary', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.rowbin', 'RowBinary', 'c0 Nullable(Tuple())'); + +SELECT 'Avro'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.avro', 'Avro', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.avro', 'Avro', 'c0 Nullable(Tuple())'); + +SELECT 'MsgPack'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.msgpack', 'MsgPack', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.msgpack', 'MsgPack', 'c0 Nullable(Tuple())'); + +SELECT 'BSONEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.bson', 'BSONEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.bson', 'BSONEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'JSON'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonall', 'JSON', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonall', 'JSON', 'c0 Nullable(Tuple())'); + +SELECT 'JSONCompact'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncompactall', 'JSONCompact', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncompactall', 'JSONCompact', 'c0 Nullable(Tuple())'); + +SELECT 'JSONColumns'; +INSERT INTO TABLE FUNCTION 
file(currentDatabase() || '_04019.jsoncols', 'JSONColumns', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncols', 'JSONColumns', 'c0 Nullable(Tuple())'); + +SELECT 'JSONCompactColumns'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonccols', 'JSONCompactColumns', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonccols', 'JSONCompactColumns', 'c0 Nullable(Tuple())'); + +SELECT 'JSONColumnsWithMetadata'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsoncolsmeta', 'JSONColumnsWithMetadata', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsoncolsmeta', 'JSONColumnsWithMetadata', 'c0 Nullable(Tuple())'); + +SELECT 'JSONObjectEachRow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.jsonobj', 'JSONObjectEachRow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.jsonobj', 'JSONObjectEachRow', 'c0 Nullable(Tuple())'); + +SELECT 'Buffers'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.buf', 'Buffers', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.buf', 'Buffers', 'c0 Nullable(Tuple())'); + +-- Parquet doesn't support empty tuples by design +SELECT 'Parquet'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.parquet', 'Parquet', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; -- { serverError BAD_ARGUMENTS } + +SELECT 'Arrow'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.arrow', 'Arrow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.arrow', 'Arrow', 'c0 Nullable(Tuple())'); + +SELECT 'ArrowStream'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || 
'_04019.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())'); + +SELECT 'ORC'; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04019.orc', 'ORC', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_empty_tuple; +SELECT c0 FROM file(currentDatabase() || '_04019.orc', 'ORC', 'c0 Nullable(Tuple())'); diff --git a/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.reference b/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.reference new file mode 100644 index 000000000000..5ade702466c0 --- /dev/null +++ b/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.reference @@ -0,0 +1,456 @@ +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; +-- Nullable struct with non-nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_basic; +CREATE TABLE test_nullable_tuple_basic (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_basic VALUES ((1, 'a')), (NULL), ((3, 'c')); +-- Arrow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04064.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') +\N +(3,'c') +-- ArrowStream +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04064.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') +\N +(3,'c') +-- ORC +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM 
file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') +\N +(3,'c') +-- ORC legacy (Arrow-based) reader +SELECT c0 FROM file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +\N +(3,'c') +DROP TABLE test_nullable_tuple_basic; +-- Nullable empty tuple +DROP TABLE IF EXISTS test_nullable_tuple_empty; +CREATE TABLE test_nullable_tuple_empty (c0 Nullable(Tuple())) ENGINE = Memory; +INSERT INTO test_nullable_tuple_empty VALUES (()), (NULL), (()); +-- Arrow empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.arrow', 'Arrow', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.arrow', 'Arrow', 'c0 Nullable(Tuple())'); +() +\N +() +-- ArrowStream empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())'); +() +\N +() +-- ORC empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())'); +() +\N +() +-- ORC legacy empty +SELECT c0 FROM file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())') SETTINGS input_format_orc_use_fast_decoder = 0; +() +\N +() +DROP TABLE test_nullable_tuple_empty; +-- Both struct and element nullable: Nullable(Tuple(Nullable(UInt32), String)) +DROP TABLE IF EXISTS test_nullable_tuple_both; +CREATE TABLE test_nullable_tuple_both (c0 Nullable(Tuple(Nullable(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_both VALUES ((1, 'a')), (NULL), ((NULL, 'c')), ((4, 'd')); +-- Arrow both nullable +INSERT INTO 
TABLE FUNCTION file(currentDatabase() || '_04064_both.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04064_both.arrow', 'Arrow', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); +(1,'a') +\N +(NULL,'c') +(4,'d') +-- ArrowStream both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_both.arrowstream', 'ArrowStream') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04064_both.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); +(1,'a') +\N +(NULL,'c') +(4,'d') +-- ORC both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_both.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04064_both.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); +(1,'a') +\N +(NULL,'c') +(4,'d') +-- ORC legacy both nullable +SELECT c0 FROM file(currentDatabase() || '_04064_both.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +\N +(NULL,'c') +(4,'d') +DROP TABLE test_nullable_tuple_both; +-- Non-nullable struct with nullable elements (should be unchanged) +DROP TABLE IF EXISTS test_nullable_tuple_elem; +CREATE TABLE test_nullable_tuple_elem (c0 Tuple(Nullable(UInt32), String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_elem VALUES ((1, 'a')), ((NULL, 'b')); +-- Arrow nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_elem.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04064_elem.arrow', 'Arrow', 'c0 Tuple(Nullable(UInt32), String)'); +(1,'a') +(NULL,'b') +-- ORC nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_elem.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04064_elem.orc', 'ORC', 'c0 Tuple(Nullable(UInt32), String)'); +(1,'a') 
+(NULL,'b') +-- ORC legacy nullable elements +SELECT c0 FROM file(currentDatabase() || '_04064_elem.orc', 'ORC', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +(NULL,'b') +DROP TABLE test_nullable_tuple_elem; +-- Plain non-nullable tuple (baseline, should be unchanged) +DROP TABLE IF EXISTS test_nullable_tuple_plain; +CREATE TABLE test_nullable_tuple_plain (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_plain VALUES ((1, 'a')), ((2, 'b')); +-- Arrow plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_plain.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04064_plain.arrow', 'Arrow', 'c0 Tuple(UInt32, String)'); +(1,'a') +(2,'b') +-- ORC plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_plain.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04064_plain.orc', 'ORC', 'c0 Tuple(UInt32, String)'); +(1,'a') +(2,'b') +-- ORC legacy plain +SELECT c0 FROM file(currentDatabase() || '_04064_plain.orc', 'ORC', 'c0 Tuple(UInt32, String)') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +(2,'b') +DROP TABLE test_nullable_tuple_plain; +-- Nested tuple inside nullable struct +DROP TABLE IF EXISTS test_nullable_tuple_nested; +CREATE TABLE test_nullable_tuple_nested (c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nested VALUES (((1, 'a'), 10)), (NULL), (((3, 'c'), 30)); +-- Arrow nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nested.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_nested; +SELECT c0 FROM file(currentDatabase() || '_04064_nested.arrow', 'Arrow', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))'); +((1,'a'),10) +\N +((3,'c'),30) +-- ORC nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nested.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_nested; 
+SELECT c0 FROM file(currentDatabase() || '_04064_nested.orc', 'ORC', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))'); +((1,'a'),10) +\N +((3,'c'),30) +-- ORC legacy nested +SELECT c0 FROM file(currentDatabase() || '_04064_nested.orc', 'ORC', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))') SETTINGS input_format_orc_use_fast_decoder = 0; +((1,'a'),10) +\N +((3,'c'),30) +DROP TABLE test_nullable_tuple_nested; +-- Schema inference without type hint +DROP TABLE IF EXISTS test_nullable_tuple_infer; +CREATE TABLE test_nullable_tuple_infer (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_infer VALUES ((1, 'a')), (NULL), ((3, 'c')); +-- Arrow infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_infer.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_infer; +SELECT c0 FROM file(currentDatabase() || '_04064_infer.arrow', 'Arrow'); +(1,'a') +\N +(3,'c') +-- ORC infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_infer.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_infer; +SELECT c0 FROM file(currentDatabase() || '_04064_infer.orc', 'ORC'); +(1,'a') +\N +(3,'c') +-- ORC legacy infer +SELECT c0 FROM file(currentDatabase() || '_04064_infer.orc', 'ORC') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +\N +(3,'c') +DROP TABLE test_nullable_tuple_infer; +-- Named tuple +DROP TABLE IF EXISTS test_nullable_tuple_named; +CREATE TABLE test_nullable_tuple_named (c0 Nullable(Tuple(a UInt32, b String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_named VALUES ((1, 'x')), (NULL), ((3, 'z')); +-- Arrow named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_named.arrow', 'Arrow', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04064_named.arrow', 'Arrow', 'c0 Nullable(Tuple(a UInt32, b String))'); +(1,'x') +\N +(3,'z') +-- ORC named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_named.orc', 
'ORC', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04064_named.orc', 'ORC', 'c0 Nullable(Tuple(a UInt32, b String))'); +(1,'x') +\N +(3,'z') +-- ORC legacy named +SELECT c0 FROM file(currentDatabase() || '_04064_named.orc', 'ORC', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'x') +\N +(3,'z') +DROP TABLE test_nullable_tuple_named; +-- All-NULL column +DROP TABLE IF EXISTS test_nullable_tuple_allnull; +CREATE TABLE test_nullable_tuple_allnull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_allnull VALUES (NULL), (NULL), (NULL); +-- Arrow all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_allnull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); +\N +\N +\N +-- ORC all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); +\N +\N +\N +-- ORC legacy all null +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; +\N +\N +\N +DROP TABLE test_nullable_tuple_allnull; +-- No-NULL column (nullable type, zero actual NULLs) +DROP TABLE IF EXISTS test_nullable_tuple_nonull; +CREATE TABLE test_nullable_tuple_nonull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nonull VALUES ((1, 'a')), ((2, 'b')), ((3, 'c')); +-- Arrow no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nonull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM 
test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') +(2,'b') +(3,'c') +-- ORC no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') +(2,'b') +(3,'c') +-- ORC legacy no null +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') +(2,'b') +(3,'c') +DROP TABLE test_nullable_tuple_nonull; +-- Single-element tuple +DROP TABLE IF EXISTS test_nullable_tuple_single; +CREATE TABLE test_nullable_tuple_single (c0 Nullable(Tuple(UInt32))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_single VALUES ((1,)), (NULL), ((3,)); +-- Arrow single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_single.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || '_04064_single.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32))'); +(1) +\N +(3) +-- ORC single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))'); +(1) +\N +(3) +-- ORC legacy single +SELECT c0 FROM file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1) +\N +(3) +DROP TABLE test_nullable_tuple_single; +-- Deeply nested: nullable tuple inside nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_deep; +CREATE TABLE test_nullable_tuple_deep (c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))) ENGINE = Memory; +INSERT 
INTO test_nullable_tuple_deep VALUES (((1, 'a'), 10)), (NULL), ((NULL, 20)), (((4, 'd'), 40)); +-- Arrow deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_deep.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04064_deep.arrow', 'Arrow', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))'); +((1,'a'),10) +\N +(NULL,20) +((4,'d'),40) +-- ORC deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_deep.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04064_deep.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))'); +((1,'a'),10) +\N +(NULL,20) +((4,'d'),40) +-- ORC legacy deep nested +SELECT c0 FROM file(currentDatabase() || '_04064_deep.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_orc_use_fast_decoder = 0; +((1,'a'),10) +\N +(NULL,20) +((4,'d'),40) +DROP TABLE test_nullable_tuple_deep; +-- Nullable tuple with Array element +DROP TABLE IF EXISTS test_nullable_tuple_arr; +CREATE TABLE test_nullable_tuple_arr (c0 Nullable(Tuple(Array(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr VALUES (([1, 2], 'a')), (NULL), (([3], 'c')); +-- Arrow array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04064_arr.arrow', 'Arrow', 'c0 Nullable(Tuple(Array(UInt32), String))'); +([1,2],'a') +\N +([3],'c') +-- ORC array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04064_arr.orc', 'ORC', 'c0 Nullable(Tuple(Array(UInt32), String))'); +([1,2],'a') +\N +([3],'c') +-- ORC legacy array elem +SELECT c0 FROM file(currentDatabase() || '_04064_arr.orc', 'ORC', 'c0 Nullable(Tuple(Array(UInt32), String))') 
SETTINGS input_format_orc_use_fast_decoder = 0; +([1,2],'a') +\N +([3],'c') +DROP TABLE test_nullable_tuple_arr; +-- Multiple nullable tuple columns +DROP TABLE IF EXISTS test_nullable_tuple_multi; +CREATE TABLE test_nullable_tuple_multi (c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_multi VALUES ((1, 'a'), (1.5)), (NULL, (2.5)), ((3, 'c'), NULL); +-- Arrow multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_multi.arrow', 'Arrow') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))'); +(1,'a') (1.5) +\N (2.5) +(3,'c') \N +-- ORC multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_multi.orc', 'ORC') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))'); +(1,'a') (1.5) +\N (2.5) +(3,'c') \N +-- ORC legacy multi col +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') (1.5) +\N (2.5) +(3,'c') \N +DROP TABLE test_nullable_tuple_multi; +-- Type hint mismatch: file has Nullable(Tuple(...)), read as Tuple(...) 
(strip nullable, NULLs become defaults) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch1; +CREATE TABLE test_nullable_tuple_mismatch1 (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch1 VALUES ((1, 'a')), (NULL), ((3, 'c')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch1.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_mismatch1; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch1.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_mismatch1; +-- Arrow: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.arrow', 'Arrow', 'c0 Tuple(UInt32, String)'); +(1,'a') Tuple(UInt32, String) +(0,'') Tuple(UInt32, String) +(3,'c') Tuple(UInt32, String) +-- ORC: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.orc', 'ORC', 'c0 Tuple(UInt32, String)'); +(1,'a') Tuple(UInt32, String) +(0,'') Tuple(UInt32, String) +(3,'c') Tuple(UInt32, String) +-- ORC legacy: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.orc', 'ORC', 'c0 Tuple(UInt32, String)') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') Tuple(UInt32, String) +(0,'') Tuple(UInt32, String) +(3,'c') Tuple(UInt32, String) +DROP TABLE test_nullable_tuple_mismatch1; +-- Type hint mismatch: file has Tuple(...), read as Nullable(Tuple(...)) (add nullable wrapper) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch2; +CREATE TABLE test_nullable_tuple_mismatch2 (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch2 VALUES ((1, 'a')), ((2, 'b')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch2.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_mismatch2; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch2.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_mismatch2; +-- Arrow: read 
non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch2.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') Nullable(Tuple(UInt32, String)) +(2,'b') Nullable(Tuple(UInt32, String)) +-- ORC: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch2.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); +(1,'a') Nullable(Tuple(UInt32, String)) +(2,'b') Nullable(Tuple(UInt32, String)) +-- ORC legacy: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch2.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; +(1,'a') Nullable(Tuple(UInt32, String)) +(2,'b') Nullable(Tuple(UInt32, String)) +DROP TABLE test_nullable_tuple_mismatch2; +-- Schema inference: DESCRIBE without type hint shows inferred type +DROP TABLE IF EXISTS test_nullable_tuple_describe; +CREATE TABLE test_nullable_tuple_describe (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_describe VALUES ((1, 'a')), (NULL), ((3, 'c')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_describe.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_describe; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_describe.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_describe; +-- Arrow: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_describe.arrow', 'Arrow'); +(1,'a') Nullable(Tuple(`1` UInt32, `2` String)) +\N Nullable(Tuple(`1` UInt32, `2` String)) +(3,'c') Nullable(Tuple(`1` UInt32, `2` String)) +-- ORC: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_describe.orc', 'ORC'); +(1,'a') Nullable(Tuple(`1` Nullable(Int32), `2` Nullable(String))) +\N Nullable(Tuple(`1` Nullable(Int32), `2` Nullable(String))) +(3,'c') Nullable(Tuple(`1` Nullable(Int32), `2` Nullable(String))) +DROP TABLE 
test_nullable_tuple_describe; +-- Array(Nullable(Tuple)) flattened via import_nested: struct-level NULLs should propagate to elements +DROP TABLE IF EXISTS test_nullable_tuple_import_nested; +CREATE TABLE test_nullable_tuple_import_nested (c0 Array(Nullable(Tuple(a UInt32, b String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_import_nested VALUES ([(1, 'a'), NULL, (3, 'c')]); +-- Arrow import_nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_import_nested.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_import_nested; +SELECT * FROM file(currentDatabase() || '_04064_import_nested.arrow', 'Arrow', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_arrow_import_nested = 1; +[1,NULL,3] ['a',NULL,'c'] +-- ORC import_nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_import_nested.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_import_nested; +SELECT * FROM file(currentDatabase() || '_04064_import_nested.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_orc_import_nested = 1; +[1,NULL,3] ['a',NULL,'c'] +-- ORC legacy import_nested +SELECT * FROM file(currentDatabase() || '_04064_import_nested.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_orc_import_nested = 1, input_format_orc_use_fast_decoder = 0; +[1,NULL,3] ['a',NULL,'c'] +DROP TABLE test_nullable_tuple_import_nested; +-- Array(Nullable(Tuple)) without named elements: round-trip as a single column, no flattening +DROP TABLE IF EXISTS test_nullable_tuple_arr_unnamed; +CREATE TABLE test_nullable_tuple_arr_unnamed (c0 Array(Nullable(Tuple(UInt32, String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_unnamed VALUES ([(1, 'a'), NULL, (3, 'c')]); +-- Arrow unnamed +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_unnamed.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr_unnamed; +SELECT c0 FROM 
file(currentDatabase() || '_04064_arr_unnamed.arrow', 'Arrow', 'c0 Array(Nullable(Tuple(UInt32, String)))'); +[(1,'a'),NULL,(3,'c')] +-- ORC unnamed +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr_unnamed; +SELECT c0 FROM file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC', 'c0 Array(Nullable(Tuple(UInt32, String)))'); +[(1,'a'),NULL,(3,'c')] +-- ORC legacy unnamed +SELECT c0 FROM file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_orc_use_fast_decoder = 0; +[(1,'a'),NULL,(3,'c')] +DROP TABLE test_nullable_tuple_arr_unnamed; +-- Array(Nullable(Tuple)) with Array element inside: import_nested flattens, Array defaults to [] at null positions +DROP TABLE IF EXISTS test_nullable_tuple_arr_nested_elem; +CREATE TABLE test_nullable_tuple_arr_nested_elem (c0 Array(Nullable(Tuple(a UInt32, b Array(UInt32))))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_nested_elem VALUES ([(1, [10, 20]), NULL, (3, [30])]); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_nested_elem.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; +-- Arrow import_nested: scalar becomes Nullable, Array defaults to [] at null struct positions +SELECT * FROM file(currentDatabase() || '_04064_arr_nested_elem.arrow', 'Arrow', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_arrow_import_nested = 1; +[1,NULL,3] [[10,20],[],[30]] +-- ORC import_nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_nested_elem.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; +SELECT * FROM file(currentDatabase() || '_04064_arr_nested_elem.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_orc_import_nested = 1; +[1,NULL,3] [[10,20],[],[30]] +-- ORC legacy import_nested +SELECT * FROM file(currentDatabase() || 
'_04064_arr_nested_elem.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_orc_import_nested = 1, input_format_orc_use_fast_decoder = 0; +[1,NULL,3] [[10,20],[],[30]] +DROP TABLE test_nullable_tuple_arr_nested_elem; +-- LowCardinality(Nullable(String)) hint with no physical nulls in the file: the ORC reader must still wrap the column as nullable +DROP TABLE IF EXISTS test_nullable_tuple_lc_string; +CREATE TABLE test_nullable_tuple_lc_string (c0 String) ENGINE = Memory; +INSERT INTO test_nullable_tuple_lc_string VALUES ('hello'), ('world'); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_lc_str.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_lc_string; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_lc_str.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_lc_string; +-- Arrow: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.arrow', 'Arrow', 'c0 LowCardinality(Nullable(String))'); +hello LowCardinality(Nullable(String)) +world LowCardinality(Nullable(String)) +-- ORC: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.orc', 'ORC', 'c0 LowCardinality(Nullable(String))'); +hello LowCardinality(Nullable(String)) +world LowCardinality(Nullable(String)) +-- ORC legacy: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.orc', 'ORC', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_orc_use_fast_decoder = 0; +hello LowCardinality(Nullable(String)) +world LowCardinality(Nullable(String)) +DROP TABLE test_nullable_tuple_lc_string; diff --git a/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.sql b/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.sql new file mode 100644 index 000000000000..694c075b68e9 
--- /dev/null +++ b/tests/queries/0_stateless/04064_tuple_inside_nullable_arrow_orc_roundtrip.sql @@ -0,0 +1,399 @@ +-- Tags: no-fasttest +-- no-fasttest: Arrow and ORC formats are not available in fasttest builds + +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; + +-- Nullable struct with non-nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_basic; +CREATE TABLE test_nullable_tuple_basic (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_basic VALUES ((1, 'a')), (NULL), ((3, 'c')); + +-- Arrow +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04064.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ArrowStream +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04064.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC legacy (Arrow-based) reader +SELECT c0 FROM file(currentDatabase() || '_04064.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_basic; + +-- Nullable empty tuple +DROP TABLE IF EXISTS test_nullable_tuple_empty; +CREATE TABLE test_nullable_tuple_empty (c0 Nullable(Tuple())) ENGINE = Memory; +INSERT INTO test_nullable_tuple_empty VALUES (()), (NULL), (()); + +-- Arrow empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.arrow', 'Arrow', 'c0 
Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.arrow', 'Arrow', 'c0 Nullable(Tuple())'); + +-- ArrowStream empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple())'); + +-- ORC empty +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())') SELECT c0 FROM test_nullable_tuple_empty; +SELECT c0 FROM file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())'); + +-- ORC legacy empty +SELECT c0 FROM file(currentDatabase() || '_04064_empty.orc', 'ORC', 'c0 Nullable(Tuple())') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_empty; + +-- Both struct and element nullable: Nullable(Tuple(Nullable(UInt32), String)) +DROP TABLE IF EXISTS test_nullable_tuple_both; +CREATE TABLE test_nullable_tuple_both (c0 Nullable(Tuple(Nullable(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_both VALUES ((1, 'a')), (NULL), ((NULL, 'c')), ((4, 'd')); + +-- Arrow both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_both.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04064_both.arrow', 'Arrow', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); + +-- ArrowStream both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_both.arrowstream', 'ArrowStream') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04064_both.arrowstream', 'ArrowStream', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); + +-- ORC both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_both.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM 
file(currentDatabase() || '_04064_both.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(UInt32), String))'); + +-- ORC legacy both nullable +SELECT c0 FROM file(currentDatabase() || '_04064_both.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_both; + +-- Non-nullable struct with nullable elements (should be unchanged) +DROP TABLE IF EXISTS test_nullable_tuple_elem; +CREATE TABLE test_nullable_tuple_elem (c0 Tuple(Nullable(UInt32), String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_elem VALUES ((1, 'a')), ((NULL, 'b')); + +-- Arrow nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_elem.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04064_elem.arrow', 'Arrow', 'c0 Tuple(Nullable(UInt32), String)'); + +-- ORC nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_elem.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04064_elem.orc', 'ORC', 'c0 Tuple(Nullable(UInt32), String)'); + +-- ORC legacy nullable elements +SELECT c0 FROM file(currentDatabase() || '_04064_elem.orc', 'ORC', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_elem; + +-- Plain non-nullable tuple (baseline, should be unchanged) +DROP TABLE IF EXISTS test_nullable_tuple_plain; +CREATE TABLE test_nullable_tuple_plain (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_plain VALUES ((1, 'a')), ((2, 'b')); + +-- Arrow plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_plain.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04064_plain.arrow', 'Arrow', 'c0 Tuple(UInt32, String)'); + +-- ORC plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_plain.orc', 'ORC') SELECT c0 FROM 
test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04064_plain.orc', 'ORC', 'c0 Tuple(UInt32, String)'); + +-- ORC legacy plain +SELECT c0 FROM file(currentDatabase() || '_04064_plain.orc', 'ORC', 'c0 Tuple(UInt32, String)') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_plain; + +-- Nested tuple inside nullable struct +DROP TABLE IF EXISTS test_nullable_tuple_nested; +CREATE TABLE test_nullable_tuple_nested (c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nested VALUES (((1, 'a'), 10)), (NULL), (((3, 'c'), 30)); + +-- Arrow nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nested.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_nested; +SELECT c0 FROM file(currentDatabase() || '_04064_nested.arrow', 'Arrow', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))'); + +-- ORC nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nested.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_nested; +SELECT c0 FROM file(currentDatabase() || '_04064_nested.orc', 'ORC', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))'); + +-- ORC legacy nested +SELECT c0 FROM file(currentDatabase() || '_04064_nested.orc', 'ORC', 'c0 Nullable(Tuple(Tuple(UInt32, String), UInt64))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_nested; + +-- Schema inference without type hint +DROP TABLE IF EXISTS test_nullable_tuple_infer; +CREATE TABLE test_nullable_tuple_infer (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_infer VALUES ((1, 'a')), (NULL), ((3, 'c')); + +-- Arrow infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_infer.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_infer; +SELECT c0 FROM file(currentDatabase() || '_04064_infer.arrow', 'Arrow'); + +-- ORC infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_infer.orc', 'ORC') SELECT c0 FROM 
test_nullable_tuple_infer; +SELECT c0 FROM file(currentDatabase() || '_04064_infer.orc', 'ORC'); + +-- ORC legacy infer +SELECT c0 FROM file(currentDatabase() || '_04064_infer.orc', 'ORC') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_infer; + +-- Named tuple +DROP TABLE IF EXISTS test_nullable_tuple_named; +CREATE TABLE test_nullable_tuple_named (c0 Nullable(Tuple(a UInt32, b String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_named VALUES ((1, 'x')), (NULL), ((3, 'z')); + +-- Arrow named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_named.arrow', 'Arrow', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04064_named.arrow', 'Arrow', 'c0 Nullable(Tuple(a UInt32, b String))'); + +-- ORC named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_named.orc', 'ORC', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04064_named.orc', 'ORC', 'c0 Nullable(Tuple(a UInt32, b String))'); + +-- ORC legacy named +SELECT c0 FROM file(currentDatabase() || '_04064_named.orc', 'ORC', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_named; + +-- All-NULL column +DROP TABLE IF EXISTS test_nullable_tuple_allnull; +CREATE TABLE test_nullable_tuple_allnull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_allnull VALUES (NULL), (NULL), (NULL); + +-- Arrow all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_allnull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 
Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC legacy all null +SELECT c0 FROM file(currentDatabase() || '_04064_allnull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_allnull; + +-- No-NULL column (nullable type, zero actual NULLs) +DROP TABLE IF EXISTS test_nullable_tuple_nonull; +CREATE TABLE test_nullable_tuple_nonull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nonull VALUES ((1, 'a')), ((2, 'b')), ((3, 'c')); + +-- Arrow no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nonull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC legacy no null +SELECT c0 FROM file(currentDatabase() || '_04064_nonull.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_nonull; + +-- Single-element tuple +DROP TABLE IF EXISTS test_nullable_tuple_single; +CREATE TABLE test_nullable_tuple_single (c0 Nullable(Tuple(UInt32))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_single VALUES ((1,)), (NULL), ((3,)); + +-- Arrow single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_single.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || '_04064_single.arrow', 'Arrow', 
'c0 Nullable(Tuple(UInt32))'); + +-- ORC single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))'); + +-- ORC legacy single +SELECT c0 FROM file(currentDatabase() || '_04064_single.orc', 'ORC', 'c0 Nullable(Tuple(UInt32))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_single; + +-- Deeply nested: nullable tuple inside nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_deep; +CREATE TABLE test_nullable_tuple_deep (c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_deep VALUES (((1, 'a'), 10)), (NULL), ((NULL, 20)), (((4, 'd'), 40)); + +-- Arrow deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_deep.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04064_deep.arrow', 'Arrow', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))'); + +-- ORC deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_deep.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04064_deep.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))'); + +-- ORC legacy deep nested +SELECT c0 FROM file(currentDatabase() || '_04064_deep.orc', 'ORC', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_deep; + +-- Nullable tuple with Array element +DROP TABLE IF EXISTS test_nullable_tuple_arr; +CREATE TABLE test_nullable_tuple_arr (c0 Nullable(Tuple(Array(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr VALUES (([1, 2], 'a')), (NULL), (([3], 'c')); + +-- Arrow array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || 
'_04064_arr.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04064_arr.arrow', 'Arrow', 'c0 Nullable(Tuple(Array(UInt32), String))'); + +-- ORC array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04064_arr.orc', 'ORC', 'c0 Nullable(Tuple(Array(UInt32), String))'); + +-- ORC legacy array elem +SELECT c0 FROM file(currentDatabase() || '_04064_arr.orc', 'ORC', 'c0 Nullable(Tuple(Array(UInt32), String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_arr; + +-- Multiple nullable tuple columns +DROP TABLE IF EXISTS test_nullable_tuple_multi; +CREATE TABLE test_nullable_tuple_multi (c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_multi VALUES ((1, 'a'), (1.5)), (NULL, (2.5)), ((3, 'c'), NULL); + +-- Arrow multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_multi.arrow', 'Arrow') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))'); + +-- ORC multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_multi.orc', 'ORC') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))'); + +-- ORC legacy multi col +SELECT c0, c1 FROM file(currentDatabase() || '_04064_multi.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_multi; + +-- Type hint mismatch: file has Nullable(Tuple(...)), read as Tuple(...) 
(strip nullable, NULLs become defaults) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch1; +CREATE TABLE test_nullable_tuple_mismatch1 (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch1 VALUES ((1, 'a')), (NULL), ((3, 'c')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch1.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_mismatch1; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch1.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_mismatch1; + +-- Arrow: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.arrow', 'Arrow', 'c0 Tuple(UInt32, String)'); + +-- ORC: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.orc', 'ORC', 'c0 Tuple(UInt32, String)'); + +-- ORC legacy: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch1.orc', 'ORC', 'c0 Tuple(UInt32, String)') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_mismatch1; + +-- Type hint mismatch: file has Tuple(...), read as Nullable(Tuple(...)) (add nullable wrapper) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch2; +CREATE TABLE test_nullable_tuple_mismatch2 (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch2 VALUES ((1, 'a')), ((2, 'b')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch2.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_mismatch2; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_mismatch2.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_mismatch2; + +-- Arrow: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch2.arrow', 'Arrow', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM 
file(currentDatabase() || '_04064_mismatch2.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))'); + +-- ORC legacy: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_mismatch2.orc', 'ORC', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_mismatch2; + +-- Schema inference: DESCRIBE without type hint shows inferred type +DROP TABLE IF EXISTS test_nullable_tuple_describe; +CREATE TABLE test_nullable_tuple_describe (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_describe VALUES ((1, 'a')), (NULL), ((3, 'c')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_describe.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_describe; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_describe.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_describe; + +-- Arrow: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_describe.arrow', 'Arrow'); + +-- ORC: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_describe.orc', 'ORC'); + +DROP TABLE test_nullable_tuple_describe; + +-- Array(Nullable(Tuple)) flattened via import_nested: struct-level NULLs should propagate to elements +DROP TABLE IF EXISTS test_nullable_tuple_import_nested; +CREATE TABLE test_nullable_tuple_import_nested (c0 Array(Nullable(Tuple(a UInt32, b String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_import_nested VALUES ([(1, 'a'), NULL, (3, 'c')]); + +-- Arrow import_nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_import_nested.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_import_nested; +SELECT * FROM file(currentDatabase() || '_04064_import_nested.arrow', 'Arrow', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_arrow_import_nested = 1; + +-- ORC import_nested +INSERT INTO TABLE FUNCTION 
file(currentDatabase() || '_04064_import_nested.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_import_nested; +SELECT * FROM file(currentDatabase() || '_04064_import_nested.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_orc_import_nested = 1; + +-- ORC legacy import_nested +SELECT * FROM file(currentDatabase() || '_04064_import_nested.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_orc_import_nested = 1, input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_import_nested; + +-- Array(Nullable(Tuple)) without named elements: round-trip as a single column, no flattening +DROP TABLE IF EXISTS test_nullable_tuple_arr_unnamed; +CREATE TABLE test_nullable_tuple_arr_unnamed (c0 Array(Nullable(Tuple(UInt32, String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_unnamed VALUES ([(1, 'a'), NULL, (3, 'c')]); + +-- Arrow unnamed +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_unnamed.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr_unnamed; +SELECT c0 FROM file(currentDatabase() || '_04064_arr_unnamed.arrow', 'Arrow', 'c0 Array(Nullable(Tuple(UInt32, String)))'); + +-- ORC unnamed +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr_unnamed; +SELECT c0 FROM file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC', 'c0 Array(Nullable(Tuple(UInt32, String)))'); + +-- ORC legacy unnamed +SELECT c0 FROM file(currentDatabase() || '_04064_arr_unnamed.orc', 'ORC', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_arr_unnamed; + +-- Array(Nullable(Tuple)) with Array element inside: import_nested flattens, Array defaults to [] at null positions +DROP TABLE IF EXISTS test_nullable_tuple_arr_nested_elem; +CREATE TABLE test_nullable_tuple_arr_nested_elem (c0 Array(Nullable(Tuple(a 
UInt32, b Array(UInt32))))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_nested_elem VALUES ([(1, [10, 20]), NULL, (3, [30])]); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_nested_elem.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; + +-- Arrow import_nested: scalar becomes Nullable, Array defaults to [] at null struct positions +SELECT * FROM file(currentDatabase() || '_04064_arr_nested_elem.arrow', 'Arrow', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_arrow_import_nested = 1; + +-- ORC import_nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_arr_nested_elem.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; +SELECT * FROM file(currentDatabase() || '_04064_arr_nested_elem.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_orc_import_nested = 1; + +-- ORC legacy import_nested +SELECT * FROM file(currentDatabase() || '_04064_arr_nested_elem.orc', 'ORC', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_orc_import_nested = 1, input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_arr_nested_elem; + +-- LowCardinality(Nullable(String)) hint with no physical nulls in the file: the ORC reader must still wrap the column as nullable +DROP TABLE IF EXISTS test_nullable_tuple_lc_string; +CREATE TABLE test_nullable_tuple_lc_string (c0 String) ENGINE = Memory; +INSERT INTO test_nullable_tuple_lc_string VALUES ('hello'), ('world'); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_lc_str.arrow', 'Arrow') SELECT c0 FROM test_nullable_tuple_lc_string; +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04064_lc_str.orc', 'ORC') SELECT c0 FROM test_nullable_tuple_lc_string; + +-- Arrow: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.arrow', 'Arrow', 'c0 
LowCardinality(Nullable(String))'); + +-- ORC: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.orc', 'ORC', 'c0 LowCardinality(Nullable(String))'); + +-- ORC legacy: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04064_lc_str.orc', 'ORC', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_orc_use_fast_decoder = 0; + +DROP TABLE test_nullable_tuple_lc_string; diff --git a/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.reference b/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.reference new file mode 100644 index 000000000000..9ea83fc503fe --- /dev/null +++ b/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.reference @@ -0,0 +1,251 @@ +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; +-- Nullable struct with non-nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_basic; +CREATE TABLE test_nullable_tuple_basic (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_basic VALUES ((1, 'a')), (NULL), ((3, 'c')); +-- Parquet Arrow reader +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +\N +(3,'c') +-- Parquet V3 native reader (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_basic; +-- Both struct and element nullable: Nullable(Tuple(Nullable(UInt32), String)) +DROP 
TABLE IF EXISTS test_nullable_tuple_both; +CREATE TABLE test_nullable_tuple_both (c0 Nullable(Tuple(Nullable(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_both VALUES ((1, 'a')), (NULL), ((NULL, 'c')), ((4, 'd')); +-- Parquet Arrow reader both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_both.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04065_both.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +\N +(NULL,'c') +(4,'d') +-- Parquet V3 native reader (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_both.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_both; +-- Non-nullable struct with nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_elem; +CREATE TABLE test_nullable_tuple_elem (c0 Tuple(Nullable(UInt32), String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_elem VALUES ((1, 'a')), ((NULL, 'b')); +-- Parquet Arrow reader nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_elem.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04065_elem.parquet', 'Parquet', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +(NULL,'b') +-- Parquet V3 native reader nullable elements +SELECT c0 FROM file(currentDatabase() || '_04065_elem.parquet', 'Parquet', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_parquet_use_native_reader_v3 = 1; +(1,'a') +(NULL,'b') +DROP TABLE test_nullable_tuple_elem; +-- Plain non-nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_plain; +CREATE TABLE test_nullable_tuple_plain (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO 
test_nullable_tuple_plain VALUES ((1, 'a')), ((2, 'b')); +-- Parquet Arrow reader plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_plain.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04065_plain.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +(2,'b') +-- Parquet V3 native reader plain +SELECT c0 FROM file(currentDatabase() || '_04065_plain.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 1; +(1,'a') +(2,'b') +DROP TABLE test_nullable_tuple_plain; +-- Named tuple +DROP TABLE IF EXISTS test_nullable_tuple_named; +CREATE TABLE test_nullable_tuple_named (c0 Nullable(Tuple(a UInt32, b String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_named VALUES ((1, 'x')), (NULL), ((3, 'z')); +-- Parquet Arrow reader named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'x') +\N +(3,'z') +-- Parquet V3 native reader named (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_named; +-- All-NULL column +DROP TABLE IF EXISTS test_nullable_tuple_allnull; +CREATE TABLE test_nullable_tuple_allnull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_allnull VALUES (NULL), (NULL), (NULL); +-- Parquet Arrow reader all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') 
SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +\N +\N +\N +-- Parquet V3 native reader all null (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_allnull; +-- No-NULL column (nullable type, zero actual NULLs) +DROP TABLE IF EXISTS test_nullable_tuple_nonull; +CREATE TABLE test_nullable_tuple_nonull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nonull VALUES ((1, 'a')), ((2, 'b')), ((3, 'c')); +-- Parquet Arrow reader no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +(2,'b') +(3,'c') +-- Parquet V3 native reader no null (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_nonull; +-- Single-element tuple +DROP TABLE IF EXISTS test_nullable_tuple_single; +CREATE TABLE test_nullable_tuple_single (c0 Nullable(Tuple(UInt32))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_single VALUES ((1,)), (NULL), ((3,)); +-- Parquet Arrow reader single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_single.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || 
'_04065_single.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1) +\N +(3) +-- Parquet V3 native reader single (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_single.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_single; +-- Deeply nested: nullable tuple inside nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_deep; +CREATE TABLE test_nullable_tuple_deep (c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_deep VALUES (((1, 'a'), 10)), (NULL), ((NULL, 20)), (((4, 'd'), 40)); +-- Parquet Arrow reader deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_deep.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04065_deep.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +((1,'a'),10) +\N +(NULL,20) +((4,'d'),40) +-- Parquet V3 native reader deep nested (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_deep.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_deep; +-- Nullable tuple with Array element +DROP TABLE IF EXISTS test_nullable_tuple_arr; +CREATE TABLE test_nullable_tuple_arr (c0 Nullable(Tuple(Array(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr VALUES (([1, 2], 'a')), (NULL), (([3], 'c')); +-- Parquet Arrow reader array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04065_arr.parquet', 'Parquet', 'c0 
Nullable(Tuple(Array(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +([1,2],'a') +\N +([3],'c') +-- Parquet V3 native reader array elem (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_arr.parquet', 'Parquet', 'c0 Nullable(Tuple(Array(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_arr; +-- Multiple nullable tuple columns +DROP TABLE IF EXISTS test_nullable_tuple_multi; +CREATE TABLE test_nullable_tuple_multi (c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_multi VALUES ((1, 'a'), (1.5)), (NULL, (2.5)), ((3, 'c'), NULL); +-- Parquet Arrow reader multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_multi.parquet', 'Parquet') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04065_multi.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') (1.5) +\N (2.5) +(3,'c') \N +-- Parquet V3 native reader multi col (not yet supported) +SELECT c0, c1 FROM file(currentDatabase() || '_04065_multi.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_multi; +-- Schema inference without type hint (works for both readers, but V3 loses struct-level NULL) +DROP TABLE IF EXISTS test_nullable_tuple_infer; +CREATE TABLE test_nullable_tuple_infer (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_infer VALUES ((1, 'a')), (NULL), ((3, 'c')); +-- Parquet Arrow reader infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_infer.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_infer; +SELECT c0 FROM 
file(currentDatabase() || '_04065_infer.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') +\N +(3,'c') +DROP TABLE test_nullable_tuple_infer; +-- Type hint mismatch: file has Nullable(Tuple(...)), read as Tuple(...) (strip nullable, NULLs become defaults) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch1; +CREATE TABLE test_nullable_tuple_mismatch1 (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch1 VALUES ((1, 'a')), (NULL), ((3, 'c')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_mismatch1.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_mismatch1; +-- Parquet Arrow reader: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch1.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') Tuple(UInt32, String) +(0,'') Tuple(UInt32, String) +(3,'c') Tuple(UInt32, String) +DROP TABLE test_nullable_tuple_mismatch1; +-- Type hint mismatch: file has Tuple(...), read as Nullable(Tuple(...)) (add nullable wrapper) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch2; +CREATE TABLE test_nullable_tuple_mismatch2 (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch2 VALUES ((1, 'a')), ((2, 'b')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_mismatch2.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_mismatch2; +-- Parquet Arrow reader: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch2.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') Nullable(Tuple(UInt32, String)) +(2,'b') Nullable(Tuple(UInt32, String)) +-- Parquet V3 native reader: read non-nullable file as nullable (not yet supported) +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch2.parquet', 
'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_mismatch2; +-- Schema inference: inferred type with toTypeName +DROP TABLE IF EXISTS test_nullable_tuple_describe; +CREATE TABLE test_nullable_tuple_describe (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_describe VALUES ((1, 'a')), (NULL), ((3, 'c')); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_describe; +-- Parquet Arrow reader: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +(1,'a') Nullable(Tuple(`1` UInt32, `2` String)) +\N Nullable(Tuple(`1` UInt32, `2` String)) +(3,'c') Nullable(Tuple(`1` UInt32, `2` String)) +-- Parquet V3 native reader: inferred type (struct-level NULL not supported, becomes (NULL,NULL)) +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 1; +(1,'a') Tuple(\n `1` Nullable(UInt32),\n `2` Nullable(String)) +(NULL,NULL) Tuple(\n `1` Nullable(UInt32),\n `2` Nullable(String)) +(3,'c') Tuple(\n `1` Nullable(UInt32),\n `2` Nullable(String)) +DROP TABLE test_nullable_tuple_describe; +-- Array(Nullable(Tuple)) flattened via import_nested: struct-level NULLs should propagate to elements +DROP TABLE IF EXISTS test_nullable_tuple_import_nested; +CREATE TABLE test_nullable_tuple_import_nested (c0 Array(Nullable(Tuple(a UInt32, b String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_import_nested VALUES ([(1, 'a'), NULL, (3, 'c')]); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_import_nested; +-- Parquet Arrow reader import_nested +SELECT * FROM 
file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0, input_format_parquet_import_nested = 1; +[1,NULL,3] ['a',NULL,'c'] +-- Parquet V3 native reader import_nested +-- This works because V3 reader sees the already-flattened column names (c0.a, c0.b), not the Nullable(Tuple(...)) +SELECT * FROM file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1, input_format_parquet_import_nested = 1; +[1,NULL,3] ['a',NULL,'c'] +DROP TABLE test_nullable_tuple_import_nested; +-- Array(Nullable(Tuple)) without named elements: round-trip as a single column, no flattening +DROP TABLE IF EXISTS test_nullable_tuple_arr_unnamed; +CREATE TABLE test_nullable_tuple_arr_unnamed (c0 Array(Nullable(Tuple(UInt32, String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_unnamed VALUES ([(1, 'a'), NULL, (3, 'c')]); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr_unnamed; +-- Parquet Arrow reader unnamed +SELECT c0 FROM file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +[(1,'a'),NULL,(3,'c')] +-- Parquet V3 native reader unnamed (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } +DROP TABLE test_nullable_tuple_arr_unnamed; +-- Array(Nullable(Tuple)) with Array element inside: import_nested flattens, Array defaults to [] at null positions +DROP TABLE IF EXISTS test_nullable_tuple_arr_nested_elem; +CREATE TABLE 
test_nullable_tuple_arr_nested_elem (c0 Array(Nullable(Tuple(a UInt32, b Array(UInt32))))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_nested_elem VALUES ([(1, [10, 20]), NULL, (3, [30])]); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr_nested_elem.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; +-- Parquet Arrow reader import_nested: scalar becomes Nullable, Array defaults to [] at null struct positions +SELECT * FROM file(currentDatabase() || '_04065_arr_nested_elem.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 0, input_format_parquet_import_nested = 1; +[1,NULL,3] [[10,20],[],[30]] +-- Parquet V3 native reader import_nested +SELECT * FROM file(currentDatabase() || '_04065_arr_nested_elem.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 1, input_format_parquet_import_nested = 1; +[1,NULL,3] [[10,20],[],[30]] +DROP TABLE test_nullable_tuple_arr_nested_elem; +-- LowCardinality(Nullable(String)) hint with no physical nulls in the file: the reader must still wrap the column as nullable +DROP TABLE IF EXISTS test_nullable_tuple_lc_string; +CREATE TABLE test_nullable_tuple_lc_string (c0 String) ENGINE = Memory; +INSERT INTO test_nullable_tuple_lc_string VALUES ('hello'), ('world'); +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_lc_string; +-- Parquet Arrow reader: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; +hello LowCardinality(Nullable(String)) +world LowCardinality(Nullable(String)) +-- Parquet V3 native reader: no physical nulls, LowCardinality(Nullable(String)) 
hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; +hello LowCardinality(Nullable(String)) +world LowCardinality(Nullable(String)) +DROP TABLE test_nullable_tuple_lc_string; diff --git a/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.sql b/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.sql new file mode 100644 index 000000000000..048970dfe0a4 --- /dev/null +++ b/tests/queries/0_stateless/04065_tuple_inside_nullable_parquet_roundtrip.sql @@ -0,0 +1,275 @@ +-- Tags: no-fasttest +-- no-fasttest: Parquet format is not available in fasttest builds + +-- { echo } + +SET allow_experimental_nullable_tuple_type = 1; +SET engine_file_truncate_on_insert = 1; + +-- Nullable struct with non-nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_basic; +CREATE TABLE test_nullable_tuple_basic (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_basic VALUES ((1, 'a')), (NULL), ((3, 'c')); + +-- Parquet Arrow reader +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_basic; +SELECT c0 FROM file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_basic; + +-- Both struct and element nullable: Nullable(Tuple(Nullable(UInt32), String)) +DROP TABLE IF EXISTS test_nullable_tuple_both; +CREATE TABLE test_nullable_tuple_both (c0 Nullable(Tuple(Nullable(UInt32), String))) ENGINE = Memory; 
+INSERT INTO test_nullable_tuple_both VALUES ((1, 'a')), (NULL), ((NULL, 'c')), ((4, 'd')); + +-- Parquet Arrow reader both nullable +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_both.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_both; +SELECT c0 FROM file(currentDatabase() || '_04065_both.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_both.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_both; + +-- Non-nullable struct with nullable elements +DROP TABLE IF EXISTS test_nullable_tuple_elem; +CREATE TABLE test_nullable_tuple_elem (c0 Tuple(Nullable(UInt32), String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_elem VALUES ((1, 'a')), ((NULL, 'b')); + +-- Parquet Arrow reader nullable elements +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_elem.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_elem; +SELECT c0 FROM file(currentDatabase() || '_04065_elem.parquet', 'Parquet', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader nullable elements +SELECT c0 FROM file(currentDatabase() || '_04065_elem.parquet', 'Parquet', 'c0 Tuple(Nullable(UInt32), String)') SETTINGS input_format_parquet_use_native_reader_v3 = 1; + +DROP TABLE test_nullable_tuple_elem; + +-- Plain non-nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_plain; +CREATE TABLE test_nullable_tuple_plain (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_plain VALUES ((1, 'a')), ((2, 'b')); + +-- Parquet Arrow reader plain +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_plain.parquet', 'Parquet') SELECT c0 FROM 
test_nullable_tuple_plain; +SELECT c0 FROM file(currentDatabase() || '_04065_plain.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader plain +SELECT c0 FROM file(currentDatabase() || '_04065_plain.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 1; + +DROP TABLE test_nullable_tuple_plain; + +-- Named tuple +DROP TABLE IF EXISTS test_nullable_tuple_named; +CREATE TABLE test_nullable_tuple_named (c0 Nullable(Tuple(a UInt32, b String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_named VALUES ((1, 'x')), (NULL), ((3, 'z')); + +-- Parquet Arrow reader named +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SELECT c0 FROM test_nullable_tuple_named; +SELECT c0 FROM file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader named (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_named.parquet', 'Parquet', 'c0 Nullable(Tuple(a UInt32, b String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_named; + +-- All-NULL column +DROP TABLE IF EXISTS test_nullable_tuple_allnull; +CREATE TABLE test_nullable_tuple_allnull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_allnull VALUES (NULL), (NULL), (NULL); + +-- Parquet Arrow reader all null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_allnull; +SELECT c0 FROM file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 
native reader all null (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_allnull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_allnull; + +-- No-NULL column (nullable type, zero actual NULLs) +DROP TABLE IF EXISTS test_nullable_tuple_nonull; +CREATE TABLE test_nullable_tuple_nonull (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_nonull VALUES ((1, 'a')), ((2, 'b')), ((3, 'c')); + +-- Parquet Arrow reader no null +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SELECT c0 FROM test_nullable_tuple_nonull; +SELECT c0 FROM file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader no null (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_nonull.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_nonull; + +-- Single-element tuple +DROP TABLE IF EXISTS test_nullable_tuple_single; +CREATE TABLE test_nullable_tuple_single (c0 Nullable(Tuple(UInt32))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_single VALUES ((1,)), (NULL), ((3,)); + +-- Parquet Arrow reader single +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_single.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32))') SELECT c0 FROM test_nullable_tuple_single; +SELECT c0 FROM file(currentDatabase() || '_04065_single.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader single (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_single.parquet', 'Parquet', 'c0 
Nullable(Tuple(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_single; + +-- Deeply nested: nullable tuple inside nullable tuple +DROP TABLE IF EXISTS test_nullable_tuple_deep; +CREATE TABLE test_nullable_tuple_deep (c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_deep VALUES (((1, 'a'), 10)), (NULL), ((NULL, 20)), (((4, 'd'), 40)); + +-- Parquet Arrow reader deep nested +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_deep.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_deep; +SELECT c0 FROM file(currentDatabase() || '_04065_deep.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader deep nested (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_deep.parquet', 'Parquet', 'c0 Nullable(Tuple(Nullable(Tuple(UInt32, String)), UInt64))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_deep; + +-- Nullable tuple with Array element +DROP TABLE IF EXISTS test_nullable_tuple_arr; +CREATE TABLE test_nullable_tuple_arr (c0 Nullable(Tuple(Array(UInt32), String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr VALUES (([1, 2], 'a')), (NULL), (([3], 'c')); + +-- Parquet Arrow reader array elem +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr; +SELECT c0 FROM file(currentDatabase() || '_04065_arr.parquet', 'Parquet', 'c0 Nullable(Tuple(Array(UInt32), String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader array elem (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_arr.parquet', 'Parquet', 'c0 Nullable(Tuple(Array(UInt32), String))') SETTINGS 
input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_arr; + +-- Multiple nullable tuple columns +DROP TABLE IF EXISTS test_nullable_tuple_multi; +CREATE TABLE test_nullable_tuple_multi (c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_multi VALUES ((1, 'a'), (1.5)), (NULL, (2.5)), ((3, 'c'), NULL); + +-- Parquet Arrow reader multi col +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_multi.parquet', 'Parquet') SELECT c0, c1 FROM test_nullable_tuple_multi; +SELECT c0, c1 FROM file(currentDatabase() || '_04065_multi.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader multi col (not yet supported) +SELECT c0, c1 FROM file(currentDatabase() || '_04065_multi.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String)), c1 Nullable(Tuple(Float64))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_multi; + +-- Schema inference without type hint (works for both readers, but V3 loses struct-level NULL) +DROP TABLE IF EXISTS test_nullable_tuple_infer; +CREATE TABLE test_nullable_tuple_infer (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_infer VALUES ((1, 'a')), (NULL), ((3, 'c')); + +-- Parquet Arrow reader infer +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_infer.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_infer; +SELECT c0 FROM file(currentDatabase() || '_04065_infer.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +DROP TABLE test_nullable_tuple_infer; + +-- Type hint mismatch: file has Nullable(Tuple(...)), read as Tuple(...) 
(strip nullable, NULLs become defaults) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch1; +CREATE TABLE test_nullable_tuple_mismatch1 (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch1 VALUES ((1, 'a')), (NULL), ((3, 'c')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_mismatch1.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_mismatch1; + +-- Parquet Arrow reader: read nullable file as non-nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch1.parquet', 'Parquet', 'c0 Tuple(UInt32, String)') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +DROP TABLE test_nullable_tuple_mismatch1; + +-- Type hint mismatch: file has Tuple(...), read as Nullable(Tuple(...)) (add nullable wrapper) +DROP TABLE IF EXISTS test_nullable_tuple_mismatch2; +CREATE TABLE test_nullable_tuple_mismatch2 (c0 Tuple(UInt32, String)) ENGINE = Memory; +INSERT INTO test_nullable_tuple_mismatch2 VALUES ((1, 'a')), ((2, 'b')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_mismatch2.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_mismatch2; + +-- Parquet Arrow reader: read non-nullable file as nullable +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch2.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader: read non-nullable file as nullable (not yet supported) +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_mismatch2.parquet', 'Parquet', 'c0 Nullable(Tuple(UInt32, String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_mismatch2; + +-- Schema inference: inferred type with toTypeName +DROP TABLE IF EXISTS test_nullable_tuple_describe; +CREATE TABLE test_nullable_tuple_describe (c0 Nullable(Tuple(UInt32, String))) ENGINE = Memory; +INSERT INTO 
test_nullable_tuple_describe VALUES ((1, 'a')), (NULL), ((3, 'c')); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_describe; + +-- Parquet Arrow reader: inferred type +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader: inferred type (struct-level NULL not supported, becomes (NULL,NULL)) +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_describe.parquet', 'Parquet') SETTINGS input_format_parquet_use_native_reader_v3 = 1; + +DROP TABLE test_nullable_tuple_describe; + +-- Array(Nullable(Tuple)) flattened via import_nested: struct-level NULLs should propagate to elements +DROP TABLE IF EXISTS test_nullable_tuple_import_nested; +CREATE TABLE test_nullable_tuple_import_nested (c0 Array(Nullable(Tuple(a UInt32, b String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_import_nested VALUES ([(1, 'a'), NULL, (3, 'c')]); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_import_nested; + +-- Parquet Arrow reader import_nested +SELECT * FROM file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0, input_format_parquet_import_nested = 1; + +-- Parquet V3 native reader import_nested +-- This works because V3 reader sees the already-flattened column names (c0.a, c0.b), not the Nullable(Tuple(...)) +SELECT * FROM file(currentDatabase() || '_04065_import_nested.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1, input_format_parquet_import_nested = 1; + +DROP TABLE test_nullable_tuple_import_nested; + +-- Array(Nullable(Tuple)) without named 
elements: round-trip as a single column, no flattening +DROP TABLE IF EXISTS test_nullable_tuple_arr_unnamed; +CREATE TABLE test_nullable_tuple_arr_unnamed (c0 Array(Nullable(Tuple(UInt32, String)))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_unnamed VALUES ([(1, 'a'), NULL, (3, 'c')]); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr_unnamed; + +-- Parquet Arrow reader unnamed +SELECT c0 FROM file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader unnamed (not yet supported) +SELECT c0 FROM file(currentDatabase() || '_04065_arr_unnamed.parquet', 'Parquet', 'c0 Array(Nullable(Tuple(UInt32, String)))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; -- { serverError TYPE_MISMATCH } + +DROP TABLE test_nullable_tuple_arr_unnamed; + +-- Array(Nullable(Tuple)) with Array element inside: import_nested flattens, Array defaults to [] at null positions +DROP TABLE IF EXISTS test_nullable_tuple_arr_nested_elem; +CREATE TABLE test_nullable_tuple_arr_nested_elem (c0 Array(Nullable(Tuple(a UInt32, b Array(UInt32))))) ENGINE = Memory; +INSERT INTO test_nullable_tuple_arr_nested_elem VALUES ([(1, [10, 20]), NULL, (3, [30])]); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_arr_nested_elem.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_arr_nested_elem; + +-- Parquet Arrow reader import_nested: scalar becomes Nullable, Array defaults to [] at null struct positions +SELECT * FROM file(currentDatabase() || '_04065_arr_nested_elem.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 0, input_format_parquet_import_nested = 1; + +-- Parquet V3 native reader import_nested +SELECT * FROM file(currentDatabase() || 
'_04065_arr_nested_elem.parquet', 'Parquet', '`c0.a` Array(Nullable(UInt32)), `c0.b` Array(Array(UInt32))') SETTINGS input_format_parquet_use_native_reader_v3 = 1, input_format_parquet_import_nested = 1; + +DROP TABLE test_nullable_tuple_arr_nested_elem; + +-- LowCardinality(Nullable(String)) hint with no physical nulls in the file: the reader must still wrap the column as nullable +DROP TABLE IF EXISTS test_nullable_tuple_lc_string; +CREATE TABLE test_nullable_tuple_lc_string (c0 String) ENGINE = Memory; +INSERT INTO test_nullable_tuple_lc_string VALUES ('hello'), ('world'); + +INSERT INTO TABLE FUNCTION file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet') SELECT c0 FROM test_nullable_tuple_lc_string; + +-- Parquet Arrow reader: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 0; + +-- Parquet V3 native reader: no physical nulls, LowCardinality(Nullable(String)) hint +SELECT c0, toTypeName(c0) FROM file(currentDatabase() || '_04065_lc_str.parquet', 'Parquet', 'c0 LowCardinality(Nullable(String))') SETTINGS input_format_parquet_use_native_reader_v3 = 1; + +DROP TABLE test_nullable_tuple_lc_string; diff --git a/tmp/source_pr_chcolumn.diff b/tmp/source_pr_chcolumn.diff new file mode 100644 index 000000000000..770f93ddfec9 --- /dev/null +++ b/tmp/source_pr_chcolumn.diff @@ -0,0 +1,1048 @@ +commit e02e0dd65eb5d5a12f597fff1814dce54451936f +Merge: 02a33d49205 21831490ef1 +Author: Yakov Olkhovskiy <99031427+yakov-olkhovskiy@users.noreply.github.com> +Date: Tue Apr 14 20:20:08 2026 +0000 + + Merge pull request #91170 from ClickHouse/feat-arrowflight-impl + + Add Arrow Flight SQL support + +diff --git a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp +index 57479db6ed3..c7a505807aa 100644 +--- 
a/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp ++++ b/src/Processors/Formats/Impl/CHColumnToArrowColumn.cpp +@@ -3,6 +3,7 @@ + #if USE_ARROW || USE_PARQUET + + #include ++#include + #include + #include + #include +@@ -10,6 +11,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -20,6 +22,7 @@ + #include + #include + #include ++#include + #include + #include + #include +@@ -31,8 +34,19 @@ + #include + #include + #include ++#include ++#include ++#include ++#include ++#include ++#include + #include + ++#include ++#include ++#include ++#include ++ + #define FOR_INTERNAL_NUMERIC_TYPES(M) \ + M(Int8, arrow::Int8Builder) \ + M(UInt16, arrow::UInt16Builder) \ +@@ -66,10 +80,10 @@ namespace DB + namespace ErrorCodes + { + extern const int UNKNOWN_EXCEPTION; +- extern const int UNKNOWN_TYPE; + extern const int LOGICAL_ERROR; + extern const int DECIMAL_OVERFLOW; + extern const int ILLEGAL_COLUMN; ++ extern const int UNKNOWN_TYPE; + } + + class ArrowUUIDExtensionType : public arrow::ExtensionType +@@ -134,6 +148,55 @@ namespace DB + throw Exception(ErrorCodes::UNKNOWN_EXCEPTION, "Error with a {} column \"{}\": {}.", format_name, column_name, status.ToString()); + } + ++ template ++ static ResultType checkResult(arrow::Result && result, const String & column_name, const String & format_name) ++ { ++ checkStatus(result.status(), column_name, format_name); ++ return std::move(result).ValueUnsafe(); ++ } ++ ++ static std::shared_ptr nullBytemapToArrowBitmap( ++ const PaddedPODArray * null_bytemap, ++ const String & column_name, ++ const String & format_name, ++ size_t start, ++ size_t end) ++ { ++ if (!null_bytemap) ++ return nullptr; ++ ++ int64_t length = static_cast(end - start); ++ auto bitmap = checkResult(arrow::AllocateEmptyBitmap(length), column_name, format_name); ++ auto * data = bitmap->mutable_data(); ++ for (size_t i = 0; i < static_cast(length); ++i) ++ { ++ if (!(*null_bytemap)[start + i]) ++ arrow::bit_util::SetBit(data, 
static_cast(i)); ++ } ++ return bitmap; ++ } ++ ++ static void fillArrowArrayWithRawColumnData( ++ ColumnPtr write_column, ++ const PaddedPODArray * null_bytemap, ++ const String & format_name, ++ arrow::ArrayBuilder* array_builder, ++ size_t start, ++ size_t end) ++ { ++ arrow::BinaryBuilder & builder = assert_cast(*array_builder); ++ arrow::Status status; ++ ++ for (size_t value_i = start; value_i < end; ++value_i) ++ { ++ if (null_bytemap && (*null_bytemap)[value_i]) ++ status = builder.AppendNull(); ++ else ++ status = builder.Append(write_column->getDataAt(value_i)); ++ checkStatus(status, write_column->getName(), format_name); ++ } ++ } ++ + /// Invert values since Arrow interprets 1 as a non-null value, while CH as a null + static PaddedPODArray revertNullByteMap(const PaddedPODArray * null_bytemap, size_t start, size_t end) + { +@@ -306,9 +369,9 @@ namespace DB + } + } + +- static void fillArrowArray( ++ static std::shared_ptr fillArrowArray( + const String & column_name, +- ColumnPtr & column, ++ ColumnPtr column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + arrow::ArrayBuilder * array_builder, +@@ -318,44 +381,166 @@ namespace DB + const CHColumnToArrowColumn::Settings & settings, + std::unordered_map & dictionary_values); + +- template +- static void fillArrowArrayWithArrayColumnData( +- const String & column_name, +- ColumnPtr & column, +- const DataTypePtr & column_type, +- const PaddedPODArray *, +- arrow::ArrayBuilder * array_builder, +- String format_name, ++ ++ static std::shared_ptr getArrowType( ++ DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder = false); ++ ++ ++ static std::shared_ptr buildArrowDenseUnionArrayWithVariantColumnData( ++ const ColumnVariant & column, ++ const DataTypeVariant & column_type, ++ const PaddedPODArray * null_bytemap, ++ const String 
& format_name, + size_t start, + size_t end, + const CHColumnToArrowColumn::Settings & settings, + std::unordered_map & dictionary_values) + { +- const auto * column_array = assert_cast(column.get()); +- ColumnPtr nested_column = column_array->getDataPtr(); +- DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); +- const auto & offsets = column_array->getOffsets(); ++ size_t size = end - start; ++ const auto & column_offsets = column.getOffsets(); ++ const auto & discriminators = column.getLocalDiscriminators(); ++ arrow::Int8Builder type_ids_builder; + +- Builder & builder = assert_cast(*array_builder); +- arrow::ArrayBuilder * value_builder = builder.value_builder(); +- arrow::Status components_status; ++ const auto num_variants = column.getNumVariants(); ++ if (num_variants > static_cast(std::numeric_limits::max())) ++ throw Exception( ++ ErrorCodes::ILLEGAL_COLUMN, ++ "Cannot convert Variant with {} nested types to {} Arrow DenseUnion: maximum supported is {} ", ++ num_variants, ++ format_name, ++ static_cast(std::numeric_limits::max())); + +- for (size_t array_idx = start; array_idx < end; ++array_idx) ++ std::vector starts(num_variants); ++ std::vector ends(num_variants); ++ arrow::Status status; ++ /// Here we are doing slicing - there is no clear specification on ColumnVariant having ++ /// offsets being monotonic and contiguous (though from current code it seems they are), ++ /// Arrow DenseUnion explicitly requires monotonicity, so we are going to tolerate non-contiguous ++ /// offsets, but raise an exception for violation of monotonicity. ++ for (size_t idx = start; idx < discriminators.size() && idx < end; ++idx) + { +- /// Start new array. 
+- components_status = builder.Append(); +- checkStatus(components_status, nested_column->getName(), format_name); +- +- /// Pass null null_map, because fillArrowArray will decide whether nested_type is nullable, if nullable, it will create a new null_map from nested_column +- /// Note that it is only needed by gluten(https://github.com/oap-project/gluten), because array type in gluten is by default nullable. +- /// And it does not influence the original ClickHouse logic, because null_map passed to fillArrowArrayWithArrayColumnData is always nullptr for ClickHouse doesn't allow nullable complex types including array type. +- fillArrowArray(column_name, nested_column, nested_type, nullptr, value_builder, format_name, offsets[array_idx - 1], offsets[array_idx], settings, dictionary_values); ++ const auto & discriminator = discriminators[idx]; ++ if (discriminator != ColumnVariant::NULL_DISCRIMINATOR) ++ { ++ auto global_discr = column.globalDiscriminatorByLocal(discriminator); ++ if (ends[global_discr] == 0) ++ starts[global_discr] = column_offsets[idx]; ++ else if (column_offsets[idx] < ends[global_discr]) ++ throw Exception( ++ ErrorCodes::ILLEGAL_COLUMN, ++ "Cannot convert Variant to {} Arrow DenseUnion: " ++ "variant offsets are not monotonic for discriminator {}", ++ format_name, std::to_string(global_discr)); ++ ends[global_discr] = column_offsets[idx] + 1; ++ } ++ ++ if (discriminator == ColumnVariant::NULL_DISCRIMINATOR || (null_bytemap && (*null_bytemap)[idx])) ++ status = type_ids_builder.Append(static_cast(num_variants)); ++ else ++ status = type_ids_builder.Append(static_cast(column.globalDiscriminatorByLocal(discriminator))); ++ ++ checkStatus(status, "type_ids", format_name); ++ } ++ ++ std::shared_ptr type_ids_array; ++ status = type_ids_builder.Finish(&type_ids_array); ++ checkStatus(status, "type_ids", format_name); ++ ++ ++ arrow::ArrayVector children; ++ for (size_t i = 0; i < column.getNumVariants(); ++i) ++ { ++ const auto & variant = 
column.getVariantPtrByGlobalDiscriminator(i); ++ ++ bool is_column_nullable = false; ++ auto arrow_type = getArrowType( ++ column_type.getVariant(i), ++ variant, ++ variant->getName(), ++ format_name, ++ settings, ++ &is_column_nullable); ++ ++ std::unique_ptr variant_array_builder; ++ status = MakeBuilder(arrow::default_memory_pool(), arrow_type, &variant_array_builder); ++ checkStatus(status, variant->getName(), format_name); ++ ++ if (ends[i] == 0) ++ { ++ auto empty_array = checkResult(arrow::MakeArrayOfNull(arrow_type, 0), variant->getName(), format_name); ++ children.push_back(empty_array); ++ } ++ else ++ { ++ std::shared_ptr variant_arrow_array = fillArrowArray( ++ variant->getName(), ++ variant, ++ column_type.getVariant(i), ++ nullptr, ++ variant_array_builder.get(), ++ format_name, ++ starts[i], ++ ends[i], ++ settings, ++ dictionary_values); ++ ++ children.push_back(variant_arrow_array); ++ } + } ++ children.push_back(std::make_shared(1)); ++ ++ arrow::Int32Builder offsets_builder; ++ ++ /// column_offsets should be sanitized because NULL_DISCRIMINATOR positions in ColumnVariant ++ /// makes offsets at these positions irrelevant (and they can have unspecified values), ++ /// but for arrow dense union they are pointing to an actual NULL array ++ auto to_arrow_offset = [&](const auto & tuple) -> int32_t ++ { ++ const auto & discriminator = boost::get<0>(tuple); ++ const auto & column_offset = boost::get<1>(tuple); ++ ++ if constexpr (std::tuple_size_v> == 3) ++ if (static_cast(boost::get<2>(tuple))) ++ return 0; ++ if (discriminator == ColumnVariant::NULL_DISCRIMINATOR) ++ return 0; ++ ++ const auto offset = column_offset - starts[column.globalDiscriminatorByLocal(discriminator)]; ++ if (offset > static_cast(std::numeric_limits::max())) ++ throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot build Arrow DenseUnion: offset {} is out of Int32 range", offset); ++ return static_cast(offset); ++ }; ++ ++ auto append_offsets = [&](Ts&&... 
args) ++ { ++ auto begin_it = boost::make_transform_iterator( ++ boost::make_zip_iterator(boost::make_tuple((args->begin() + start)...)), ++ to_arrow_offset ++ ); ++ auto end_it = boost::make_transform_iterator( ++ boost::make_zip_iterator(boost::make_tuple((args->begin() + start + size)...)), ++ to_arrow_offset ++ ); ++ return offsets_builder.AppendValues(begin_it, end_it); ++ }; ++ ++ if (null_bytemap) ++ status = append_offsets(&discriminators, &column_offsets, null_bytemap); ++ else ++ status = append_offsets(&discriminators, &column_offsets); ++ ++ checkStatus(status, "offsets", format_name); ++ std::shared_ptr offsets_array; ++ status = offsets_builder.Finish(&offsets_array); ++ checkStatus(status, "offsets", format_name); ++ ++ return checkResult(arrow::DenseUnionArray::Make(*type_ids_array, *offsets_array, children), "type_ids", format_name); + } + +- static void fillArrowArrayWithTupleColumnData( ++ ++ static std::shared_ptr buildArrowStructArrayWithTupleColumnData( + const String & column_name, +- ColumnPtr & column, ++ const ColumnPtr & column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + arrow::ArrayBuilder * array_builder, +@@ -372,51 +557,84 @@ namespace DB + + arrow::StructBuilder & builder = assert_cast(*array_builder); + ++ if (column_tuple->tupleSize() == 0) ++ { ++ for (size_t i = start; i != end; ++i) ++ checkStatus(builder.Append(), column->getName(), format_name); ++ return checkResult(builder.Finish(), column_name, format_name); ++ } ++ ++ arrow::ArrayVector children; ++ + for (size_t i = 0; i != column_tuple->tupleSize(); ++i) + { + ColumnPtr nested_column = column_tuple->getColumnPtr(i); +- fillArrowArray( +- column_name + "." + nested_names[i], ++ auto name = column_name + "." 
+ nested_names[i]; ++ std::shared_ptr nested_arrow_array = fillArrowArray( ++ name, + nested_column, nested_types[i], null_bytemap, + builder.field_builder(static_cast(i)), + format_name, + start, end, + settings, + dictionary_values); +- } + +- for (size_t i = start; i != end; ++i) +- { +- auto status = builder.Append(); +- checkStatus(status, column->getName(), format_name); ++ children.push_back(nested_arrow_array); + } ++ ++ auto null_bitmap = nullBytemapToArrowBitmap(null_bytemap, column_name, format_name, start, end); ++ return checkResult(arrow::StructArray::Make(children, builder.type()->fields(), null_bitmap), column_name, format_name); + } + +- template +- static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) ++ template ++ requires (std::integral && std::integral) ++ static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) + { +- const PaddedPODArray & data = assert_cast *>(column.get())->getData(); +- PaddedPODArray result; ++ const PaddedPODArray & data = assert_cast *>(column.get())->getData(); ++ PaddedPODArray result; + result.reserve(end - start); ++ ++ auto checked_cast = [](From value) -> To ++ { ++ constexpr bool always_safe = ++ // same signedness, destination has at least as many value bits ++ (std::numeric_limits::is_signed == std::numeric_limits::is_signed ++ && std::numeric_limits::digits >= std::numeric_limits::digits) ++ // unsigned -> signed is safe only if destination has strictly more value bits ++ || (!std::numeric_limits::is_signed ++ && std::numeric_limits::is_signed ++ && std::numeric_limits::digits > std::numeric_limits::digits); ++ ++ if constexpr (always_safe) ++ return static_cast(value); ++ ++ To converted{}; ++ if (!accurate::convertNumeric(value, converted)) ++ throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot convert index {} to target type without overflow", std::to_string(value)); ++ return converted; ++ }; ++ + if (shift) +- 
std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), [](T value) { return Int64(value) - 1; }); ++ std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), [&](From value) { return checked_cast(value) - 1; }); + else +- std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), [](T value) { return Int64(value); }); ++ std::transform(data.begin() + start, data.begin() + end, std::back_inserter(result), checked_cast); + return result; + } + +- static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) ++ template ++ requires std::integral ++ static PaddedPODArray extractIndexes(ColumnPtr column, size_t start, size_t end, bool shift) + { + switch (column->getDataType()) + { + case TypeIndex::UInt8: +- return extractIndexes(column, start, end, shift); ++ return extractIndexes(column, start, end, shift); + case TypeIndex::UInt16: +- return extractIndexes(column, start, end, shift); ++ return extractIndexes(column, start, end, shift); + case TypeIndex::UInt32: +- return extractIndexes(column, start, end, shift); ++ return extractIndexes(column, start, end, shift); + case TypeIndex::UInt64: +- return extractIndexes(column, start, end, shift); ++ return extractIndexes(column, start, end, shift); + default: + throw Exception(ErrorCodes::LOGICAL_ERROR, "Indexes column must be ColumnUInt, got {}.", column->getName()); + } +@@ -474,7 +692,7 @@ namespace DB + /// We can use Int32/UInt32/Int64/UInt64 type for indexes. 
+ const auto * indexes_int32_type = typeid_cast(dict_indexes_arrow_type.get()); + const auto * indexes_uint32_type = typeid_cast(dict_indexes_arrow_type.get()); +- const auto * indexes_int64_type = typeid_cast(dict_indexes_arrow_type.get()); ++ const auto * indexes_int64_type = typeid_cast(dict_indexes_arrow_type.get()); + if ((indexes_int32_type && dict_size > INT32_MAX) || (indexes_uint32_type && dict_size > UINT32_MAX) || (indexes_int64_type && dict_size > INT64_MAX)) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, +@@ -482,10 +700,80 @@ namespace DB + " resulting dictionary size exceeds the max value of index type {}", dict_indexes_arrow_type->name()); + } + ++ static std::shared_ptr buildArrowListArrayWithArrayColumnData( ++ const String & column_name, ++ const ColumnPtr & column, ++ const DataTypePtr & column_type, ++ const PaddedPODArray * null_bytemap, ++ arrow::ArrayBuilder * array_builder, ++ String format_name, ++ size_t start, ++ size_t end, ++ const CHColumnToArrowColumn::Settings & settings, ++ std::unordered_map & dictionary_values) ++ { ++ const auto * column_array = assert_cast(column.get()); ++ const auto * type_array = assert_cast(column_type.get()); ++ ++ const auto column_offsets = assert_cast(column_array->getOffsetsColumn()).getPtr(); ++ size_t offsets_start = start > 0 ? start - 1 : 0; ++ size_t offsets_view_start = start > 0 ? 1 : 0; ++ auto offsets = extractIndexes(column_offsets, offsets_start, end, false); ++ size_t values_start = start == 0 ? 0 : offsets[0]; ++ size_t values_end = offsets.empty() ? 
values_start : offsets.back(); ++ ++ arrow::ListBuilder & builder = assert_cast(*array_builder); ++ ++ auto data_array = fillArrowArray(column_name, column_array->getDataPtr(), type_array->getNestedType(), nullptr, builder.value_builder(), format_name, values_start, values_end, settings, dictionary_values); ++ ++ arrow::Status status; ++ arrow::Int32Builder offsets_builder; ++ status = offsets_builder.Append(0); ++ checkStatus(status, column_name, format_name); ++ for (size_t i = offsets_view_start; i < offsets.size(); ++i) ++ { ++ status = offsets_builder.Append(static_cast(offsets[i] - values_start)); ++ checkStatus(status, column_name, format_name); ++ } ++ ++ std::shared_ptr offsets_array; ++ status = offsets_builder.Finish(&offsets_array); ++ checkStatus(status, column_name, format_name); ++ ++ auto null_bitmap = nullBytemapToArrowBitmap(null_bytemap, column_name, format_name, start, end); ++ return checkResult(arrow::ListArray::FromArrays(*offsets_array, *data_array, arrow::default_memory_pool(), null_bitmap), column_name, format_name); ++ } ++ ++ static std::shared_ptr buildArrowMapArrayWithMapColumnData( ++ const String & column_name, ++ const ColumnPtr & column, ++ const DataTypePtr & column_type, ++ const PaddedPODArray * null_bytemap, ++ arrow::ArrayBuilder * array_builder, ++ String format_name, ++ size_t start, ++ size_t end, ++ const CHColumnToArrowColumn::Settings & settings, ++ std::unordered_map & dictionary_values) ++ { ++ const auto * column_map = assert_cast(column.get()); ++ auto nested_column = column_map->getNestedColumnPtr(); ++ const auto * type_map = assert_cast(column_type.get()); ++ const DataTypePtr & nested_type = type_map->getNestedType(); ++ ++ auto * map_builder = assert_cast(array_builder); ++ auto builder = checkResult(arrow::MakeBuilder(arrow::list(map_builder->value_builder()->type())), column_name, format_name); ++ ++ auto list = buildArrowListArrayWithArrayColumnData(column_name, nested_column, nested_type, null_bytemap, 
builder.get(), format_name, start, end, settings, dictionary_values); ++ auto * list_array = assert_cast(list.get()); ++ ++ return std::make_shared(map_builder->type(), list_array->length(), list_array->value_offsets(), list_array->values(), list_array->null_bitmap()); ++ } ++ + template + static void fillArrowArrayWithLowCardinalityColumnDataImpl( + const String & column_name, +- ColumnPtr & column, ++ const ColumnPtr & column, + const DataTypePtr & column_type, + const PaddedPODArray *, + arrow::ArrayBuilder * array_builder, +@@ -533,10 +821,7 @@ namespace DB + + auto dict_column = dynamic_cast(*dict_values).getNestedNotNullableColumn(); + const auto & dict_type = removeNullable(assert_cast(column_type.get())->getDictionaryType()); +- fillArrowArray(column_name, dict_column, dict_type, nullptr, values_builder.get(), format_name, is_nullable, dict_column->size(), settings, dictionary_values); +- std::shared_ptr arrow_dict_array; +- status = values_builder->Finish(&arrow_dict_array); +- checkStatus(status, column->getName(), format_name); ++ std::shared_ptr arrow_dict_array = fillArrowArray(column_name, dict_column, dict_type, nullptr, values_builder.get(), format_name, is_nullable, dict_column->size(), settings, dictionary_values); + + status = builder->InsertMemoValues(*arrow_dict_array); + checkStatus(status, column->getName(), format_name); +@@ -567,7 +852,7 @@ namespace DB + + static void fillArrowArrayWithLowCardinalityColumnData( + const String & column_name, +- ColumnPtr & column, ++ const ColumnPtr & column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + arrow::ArrayBuilder * array_builder, +@@ -873,9 +1158,9 @@ namespace DB + checkStatus(status, write_column->getName(), format_name); + } + +- static void fillArrowArray( ++ static std::shared_ptr fillArrowArray( + const String & column_name, +- ColumnPtr & column, ++ ColumnPtr column, + const DataTypePtr & column_type, + const PaddedPODArray * null_bytemap, + 
arrow::ArrayBuilder * array_builder, +@@ -885,6 +1170,11 @@ namespace DB + const CHColumnToArrowColumn::Settings & settings, + std::unordered_map & dictionary_values) + { ++ std::shared_ptr arrow_array; ++ ++ column = column->convertToFullColumnIfConst(); ++ column = column->convertToFullColumnIfReplicated(); ++ + switch (column_type->getTypeId()) + { + case TypeIndex::Nullable: +@@ -894,12 +1184,12 @@ namespace DB + DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); + const ColumnPtr & null_column = column_nullable->getNullMapColumnPtr(); + const PaddedPODArray & bytemap = assert_cast &>(*null_column).getData(); +- fillArrowArray(column_name, nested_column, nested_type, &bytemap, array_builder, format_name, start, end, settings, dictionary_values); ++ arrow_array = fillArrowArray(column_name, nested_column, nested_type, &bytemap, array_builder, format_name, start, end, settings, dictionary_values); + break; + } + case TypeIndex::String: + { +- if (settings.output_string_as_string) ++ if (settings.output_string_as_string && !array_builder->type()->Equals(arrow::binary())) + fillArrowArrayWithStringColumnData(column, null_bytemap, format_name, array_builder, start, end); + else + fillArrowArrayWithStringColumnData(column, null_bytemap, format_name, array_builder, start, end); +@@ -931,19 +1221,32 @@ namespace DB + fillArrowArrayWithDate32ColumnData(column, null_bytemap, format_name, array_builder, start, end); + break; + case TypeIndex::Array: +- fillArrowArrayWithArrayColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); ++ arrow_array = buildArrowListArrayWithArrayColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + break; + case TypeIndex::Tuple: +- fillArrowArrayWithTupleColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, 
dictionary_values); ++ arrow_array = buildArrowStructArrayWithTupleColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + break; + case TypeIndex::LowCardinality: + fillArrowArrayWithLowCardinalityColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); + break; + case TypeIndex::Map: + { +- ColumnPtr column_array = assert_cast(column.get())->getNestedColumnPtr(); +- DataTypePtr array_type = assert_cast(column_type.get())->getNestedType(); +- fillArrowArrayWithArrayColumnData(column_name, column_array, array_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); ++ arrow_array = buildArrowMapArrayWithMapColumnData(column_name, column, column_type, null_bytemap, array_builder, format_name, start, end, settings, dictionary_values); ++ break; ++ } ++ case TypeIndex::Variant: ++ { ++ const auto & column_variant = assert_cast(*column); ++ const auto & column_variant_type = assert_cast(*column_type); ++ arrow_array = buildArrowDenseUnionArrayWithVariantColumnData( ++ column_variant, ++ column_variant_type, ++ null_bytemap, ++ format_name, ++ start, ++ end, ++ settings, ++ dictionary_values); + break; + } + case TypeIndex::Decimal32: +@@ -1014,8 +1317,18 @@ namespace DB + break; + } + default: +- throw Exception(ErrorCodes::UNKNOWN_TYPE, "Internal type '{}' of a column '{}' is not supported for conversion into {} data format.", column_type->getFamilyName(), column_name, format_name); ++ if (!settings.output_unsupported_types_as_binary) ++ throw Exception(ErrorCodes::UNKNOWN_TYPE, "Internal type '{}' of a column '{}' is not supported for conversion into {} data format.", column_type->getFamilyName(), column_name, format_name); ++ fillArrowArrayWithRawColumnData(column, null_bytemap, format_name, array_builder, start, end); + } ++ ++ if (!arrow_array) ++ { ++ auto status = 
array_builder->Finish(&arrow_array); ++ checkStatus(status, column->getName(), format_name); ++ } ++ ++ return arrow_array; + } + + static std::shared_ptr getArrowTypeForLowCardinalityIndexes(ColumnPtr indexes_column, const CHColumnToArrowColumn::Settings & settings) +@@ -1062,12 +1375,18 @@ namespace DB + } + + static std::shared_ptr getArrowType( +- DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder = false) ++ DataTypePtr column_type, ColumnPtr column, const std::string & column_name, const std::string & format_name, const CHColumnToArrowColumn::Settings & settings, bool * out_is_column_nullable, bool for_builder) + { ++ if (column) ++ { ++ column = column->convertToFullColumnIfConst(); ++ column = column->convertToFullColumnIfReplicated(); ++ } ++ + if (column_type->isNullable()) + { + DataTypePtr nested_type = assert_cast(column_type.get())->getNestedType(); +- ColumnPtr nested_column = assert_cast(column.get())->getNestedColumnPtr(); ++ ColumnPtr nested_column = column ? assert_cast(column.get())->getNestedColumnPtr() : nullptr; + auto arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder); + *out_is_column_nullable = true; + return arrow_type; +@@ -1101,7 +1420,7 @@ namespace DB + if (isArray(column_type)) + { + auto nested_type = assert_cast(column_type.get())->getNestedType(); +- auto nested_column = assert_cast(column.get())->getDataPtr(); ++ auto nested_column = column ? 
assert_cast(column.get())->getDataPtr() : nullptr; + bool is_item_nullable = false; + auto nested_arrow_type = getArrowType(nested_type, nested_column, column_name, format_name, settings, &is_item_nullable, for_builder); + return arrow::list(std::make_shared("item", nested_arrow_type, is_item_nullable)); +@@ -1112,12 +1431,12 @@ namespace DB + const auto & tuple_type = assert_cast(column_type.get()); + const auto & nested_types = tuple_type->getElements(); + const auto & nested_names = tuple_type->getElementNames(); +- const auto * tuple_column = assert_cast(column.get()); ++ const auto * tuple_column = column ? assert_cast(column.get()) : nullptr; + std::vector> nested_fields; + for (size_t i = 0; i != nested_types.size(); ++i) + { + bool is_field_nullable = false; +- auto nested_arrow_type = getArrowType(nested_types[i], tuple_column->getColumnPtr(i), nested_names[i], format_name, settings, &is_field_nullable, for_builder); ++ auto nested_arrow_type = getArrowType(nested_types[i], tuple_column ? 
tuple_column->getColumnPtr(i) : nullptr, nested_names[i], format_name, settings, &is_field_nullable, for_builder); + nested_fields.push_back(std::make_shared(nested_names[i], nested_arrow_type, is_field_nullable)); + } + return arrow::struct_(nested_fields); +@@ -1126,12 +1445,23 @@ namespace DB + if (column_type->lowCardinality()) + { + auto nested_type = assert_cast(column_type.get())->getDictionaryType(); +- const auto * lc_column = assert_cast(column.get()); +- const auto & nested_column = lc_column->getDictionary().getNestedColumn(); +- const auto & indexes_column = lc_column->getIndexesPtr(); +- return arrow::dictionary( +- getArrowTypeForLowCardinalityIndexes(indexes_column, settings), +- getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder)); ++ if (column) ++ { ++ const auto * lc_column = assert_cast(column.get()); ++ const auto & nested_column = lc_column->getDictionary().getNestedColumn(); ++ const auto & indexes_column = lc_column->getIndexesPtr(); ++ return arrow::dictionary( ++ getArrowTypeForLowCardinalityIndexes(indexes_column, settings), ++ getArrowType(nested_type, nested_column, column_name, format_name, settings, out_is_column_nullable, for_builder)); ++ } ++ else ++ { ++ auto index_arrow_type = settings.use_64_bit_indexes_for_dictionary ? ++ (settings.use_signed_indexes_for_dictionary ? arrow::int64() : arrow::uint64()) : ++ (settings.use_signed_indexes_for_dictionary ? 
arrow::int32() : arrow::uint32()); ++ auto arrow_type = getArrowType(nested_type, nullptr, column_name, format_name, settings, out_is_column_nullable, for_builder); ++ return arrow::dictionary(index_arrow_type, arrow_type); ++ } + } + + if (isMap(column_type)) +@@ -1139,12 +1469,19 @@ namespace DB + const auto * map_type = assert_cast(column_type.get()); + const auto & key_type = map_type->getKeyType(); + const auto & val_type = map_type->getValueType(); +- const auto & columns = assert_cast(column.get())->getNestedData().getColumns(); ++ ColumnPtr key_column; ++ ColumnPtr value_column; ++ if (column) ++ { ++ const auto & columns = assert_cast(column.get())->getNestedData().getColumns(); ++ key_column = columns[0]; ++ value_column = columns[1]; ++ } + +- bool _is_key_nullable = false; +- auto key_arrow_type = getArrowType(key_type, columns[0], column_name, format_name, settings, &_is_key_nullable, for_builder); ++ bool is_key_nullable = false; ++ auto key_arrow_type = getArrowType(key_type, key_column, column_name, format_name, settings, &is_key_nullable, for_builder); + bool is_val_nullable = false; +- auto val_arrow_type = getArrowType(val_type, columns[1], column_name, format_name, settings, &is_val_nullable, for_builder); ++ auto val_arrow_type = getArrowType(val_type, value_column, column_name, format_name, settings, &is_val_nullable, for_builder); + + return arrow::map( + key_arrow_type, +@@ -1175,6 +1512,50 @@ namespace DB + if (isIPv4(column_type)) + return arrow::uint32(); + ++ if (isVariant(column_type)) ++ { ++ const auto * column_variant = column ? 
&assert_cast(*column) : nullptr; ++ const auto & column_variant_type = assert_cast(*column_type); ++ ++ auto size = column_variant_type.getVariants().size(); ++ if (size > static_cast(std::numeric_limits::max())) ++ { ++ throw Exception( ++ ErrorCodes::ILLEGAL_COLUMN, ++ "Cannot convert Variant with {} nested types to {} Arrow DenseUnion: maximum supported is {} ", ++ size, ++ format_name, ++ static_cast(std::numeric_limits::max())); ++ } ++ ++ arrow::FieldVector fields; ++ ++ for (size_t i = 0; i < size; ++i) ++ { ++ const auto variant = column_variant ? column_variant->getVariantPtrByGlobalDiscriminator(i) : nullptr; ++ ++ bool is_column_nullable = false; ++ auto arrow_type = getArrowType( ++ column_variant_type.getVariant(i), ++ variant, ++ variant ? variant->getName() : "variant", ++ format_name, ++ settings, ++ &is_column_nullable, ++ for_builder); ++ ++ std::string field_name = column_variant_type.getVariant(i)->getFamilyName(); ++ fields.push_back(std::make_shared(field_name, arrow_type, is_column_nullable)); ++ } ++ ++ /// Variant in CH is slightly different than in arrow - it can indicate null value by having ColumnVariant::NULL_DISCRIMINATOR ++ /// in discriminators instead of using nullable type - because of this we need to introduce additional ++ /// null array (having a single null value) to have these null values to refer to ++ fields.push_back(std::make_shared("NULL", arrow::null(), false)); ++ ++ return arrow::dense_union(fields); ++ } ++ + if (isInterval(column_type)) + { + const auto * interval_type = assert_cast(column_type.get()); +@@ -1187,6 +1568,7 @@ namespace DB + default: return arrow::int64(); + } + } ++ + if (isUUID(column_type)) + return for_builder ? 
arrow::fixed_size_binary(sizeof(UUID)) : std::make_shared(); + +@@ -1203,49 +1585,22 @@ namespace DB + return arrow_type_it->second; + } + +- throw Exception(ErrorCodes::UNKNOWN_TYPE, +- "The type '{}' of a column '{}' is not supported for conversion into {} data format.", +- column_type->getName(), column_name, format_name); ++ if (!settings.output_unsupported_types_as_binary) ++ throw Exception(ErrorCodes::UNKNOWN_TYPE, ++ "The type '{}' of a column '{}' is not supported for conversion into {} data format.", ++ column_type->getName(), column_name, format_name); ++ return arrow::binary(); + } + +- CHColumnToArrowColumn::CHColumnToArrowColumn(const Block & header, const std::string & format_name_, const Settings & settings_) +- : CHColumnToArrowColumn(header.getColumnsWithTypeAndName(), format_name_, settings_) ++ std::shared_ptr CHColumnToArrowColumn::calculateArrowSchema( ++ const ColumnsWithTypeAndName & header_columns, ++ const std::string & format_name, ++ const Chunk * chunk, ++ const Settings & settings, ++ std::optional columns_num, ++ const std::optional> & column_to_field_id ++ ) + { +- } +- +- CHColumnToArrowColumn::CHColumnToArrowColumn( +- const ColumnsWithTypeAndName & header_columns_, const std::string & format_name_, const Settings & settings_) +- : format_name(format_name_) +- , settings(settings_) +- { +- if (settings.low_cardinality_as_dictionary) +- { +- header_columns = header_columns_; +- return; +- } +- header_columns.reserve(header_columns_.size()); +- for (auto column : header_columns_) +- { +- column.type = recursiveRemoveLowCardinality(column.type); +- column.column = recursiveRemoveLowCardinality(column.column); +- header_columns.emplace_back(std::move(column)); +- } +- } +- +- std::unique_ptr CHColumnToArrowColumn::clone(bool copy_arrow_schema) const +- { +- auto res = std::make_unique(header_columns, format_name, settings); +- if (copy_arrow_schema) +- res->arrow_schema = arrow_schema; +- return res; +- } +- +- void 
CHColumnToArrowColumn::initializeArrowSchema( +- const Chunk * chunk, std::optional columns_num, const std::optional> & column_to_field_id) +- { +- if (arrow_schema) +- return; +- + if (!columns_num) + columns_num = header_columns.size(); + +@@ -1255,14 +1610,19 @@ namespace DB + for (size_t column_i = 0; column_i < *columns_num; ++column_i) + { + const ColumnWithTypeAndName & header_column = header_columns[column_i]; ++ auto column_type = header_column.type; + auto column = chunk ? chunk->getColumns()[column_i] : header_column.column; + + if (!settings.low_cardinality_as_dictionary) +- column = recursiveRemoveLowCardinality(column); ++ { ++ column_type = recursiveRemoveLowCardinality(column_type); ++ if (column) ++ column = recursiveRemoveLowCardinality(column); ++ } + + bool is_column_nullable = false; + auto arrow_type = getArrowType( +- header_column.type, ++ column_type, + column, + header_column.name, + format_name, +@@ -1293,27 +1653,26 @@ namespace DB + arrow_fields.emplace_back(std::make_shared(header_column.name, arrow_type, is_column_nullable)); + } + +- arrow_schema = std::make_shared(arrow_fields); ++ return std::make_shared(arrow_fields); + } + +- std::shared_ptr CHColumnToArrowColumn::getArrowSchema() const +- { +- if (!arrow_schema) +- throw Exception(ErrorCodes::LOGICAL_ERROR, "Arrow schema is not initialized"); +- return arrow_schema; +- } + +- void CHColumnToArrowColumn::chChunkToArrowTable( +- std::shared_ptr & res, ++ std::shared_ptr CHColumnToArrowColumn::calculateArrowTable( ++ const ColumnsWithTypeAndName & header_columns, ++ const std::string & format_name, + const std::vector & chunks, ++ const Settings & settings, + size_t columns_num, +- const std::optional> & column_to_field_id) ++ std::shared_ptr schema, ++ std::unordered_map * cached_dictionary_values) + { +- std::vector table_data(columns_num); ++ /// Map {column name : arrow dictionary}. 
++ /// To avoid converting dictionary from LowCardinality to Arrow ++ /// Dictionary every chunk we save it and reuse. ++ std::unordered_map local_dictionary_values; ++ std::unordered_map & dictionary_values = cached_dictionary_values ? *cached_dictionary_values : local_dictionary_values; + +- /// We use the first chunk to initialize the arrow schema. +- const Chunk * chunk_to_initialize_schema = chunks.empty() ? nullptr : chunks.data(); +- initializeArrowSchema(chunk_to_initialize_schema, columns_num, column_to_field_id); ++ std::vector table_data(columns_num); + + for (const auto & chunk : chunks) + { +@@ -1321,24 +1680,28 @@ namespace DB + for (size_t column_i = 0; column_i < columns_num; ++column_i) + { + const ColumnWithTypeAndName & header_column = header_columns[column_i]; ++ auto column_type = header_column.type; + auto column = chunk.getColumns()[column_i]; + + if (!settings.low_cardinality_as_dictionary) ++ { + column = recursiveRemoveLowCardinality(column); ++ column_type = recursiveRemoveLowCardinality(column_type); ++ } + + // Generate the unwrapped builder schema (safe for MakeBuilder) + bool is_column_nullable = false; + auto builder_type = getArrowType( +- header_column.type, column, header_column.name, format_name, settings, &is_column_nullable, true /* for_builder */); ++ column_type, column, header_column.name, format_name, settings, &is_column_nullable, true /* for_builder */); + + std::unique_ptr array_builder; + arrow::Status status = MakeBuilder(arrow::default_memory_pool(), builder_type, &array_builder); + checkStatus(status, column->getName(), format_name); + +- fillArrowArray( ++ std::shared_ptr arrow_array = fillArrowArray( + header_column.name, + column, +- header_column.type, ++ column_type, + nullptr, + array_builder.get(), + format_name, +@@ -1347,18 +1710,10 @@ namespace DB + settings, + dictionary_values); + +- std::shared_ptr arrow_array; +- status = array_builder->Finish(&arrow_array); +- checkStatus(status, column->getName(), 
format_name); +- + // Zero-copy cast to the extension-rich schema (handles infinite nesting) +- auto target_type = arrow_schema->field(static_cast(column_i))->type(); ++ auto target_type = schema->field(static_cast(column_i))->type(); + if (!arrow_array->type()->Equals(*target_type)) +- { +- auto view_result = arrow_array->View(target_type); +- checkStatus(view_result.status(), column->getName(), format_name); +- arrow_array = view_result.ValueOrDie(); +- } ++ arrow_array = checkResult(arrow_array->View(target_type), column->getName(), format_name); + + table_data.at(column_i).emplace_back(std::move(arrow_array)); + } +@@ -1369,7 +1724,67 @@ namespace DB + for (size_t column_i = 0; column_i < columns_num; ++column_i) + columns.emplace_back(std::make_shared(table_data.at(column_i))); + +- res = arrow::Table::Make(arrow_schema, columns); ++ return arrow::Table::Make(schema, columns); ++ } ++ ++ CHColumnToArrowColumn::CHColumnToArrowColumn(const Block & header, const std::string & format_name_, const Settings & settings_) ++ : CHColumnToArrowColumn(header.getColumnsWithTypeAndName(), format_name_, settings_) ++ { ++ } ++ ++ CHColumnToArrowColumn::CHColumnToArrowColumn( ++ const ColumnsWithTypeAndName & header_columns_, const std::string & format_name_, const Settings & settings_) ++ : format_name(format_name_) ++ , settings(settings_) ++ { ++ if (settings.low_cardinality_as_dictionary) ++ { ++ header_columns = header_columns_; ++ return; ++ } ++ header_columns.reserve(header_columns_.size()); ++ for (auto column : header_columns_) ++ { ++ column.type = recursiveRemoveLowCardinality(column.type); ++ column.column = recursiveRemoveLowCardinality(column.column); ++ header_columns.emplace_back(std::move(column)); ++ } ++ } ++ ++ std::unique_ptr CHColumnToArrowColumn::clone(bool copy_arrow_schema) const ++ { ++ auto res = std::make_unique(header_columns, format_name, settings); ++ if (copy_arrow_schema) ++ res->arrow_schema = arrow_schema; ++ return res; ++ } ++ ++ void 
CHColumnToArrowColumn::initializeArrowSchema( ++ const Chunk * chunk, std::optional columns_num, const std::optional> & column_to_field_id) ++ { ++ if (arrow_schema) ++ return; ++ arrow_schema = calculateArrowSchema(header_columns, format_name, chunk, settings, columns_num, column_to_field_id); ++ } ++ ++ std::shared_ptr CHColumnToArrowColumn::getArrowSchema() const ++ { ++ if (!arrow_schema) ++ throw Exception(ErrorCodes::LOGICAL_ERROR, "Arrow schema is not initialized"); ++ return arrow_schema; ++ } ++ ++ void CHColumnToArrowColumn::chChunkToArrowTable( ++ std::shared_ptr & res, ++ const std::vector & chunks, ++ size_t columns_num, ++ const std::optional> & column_to_field_id) ++ { ++ /// We use the first chunk to initialize the arrow schema. ++ const Chunk * chunk_to_initialize_schema = chunks.empty() ? nullptr : chunks.data(); ++ initializeArrowSchema(chunk_to_initialize_schema, columns_num, column_to_field_id); ++ ++ res = calculateArrowTable(header_columns, format_name, chunks, settings, columns_num, arrow_schema, &dictionary_values); + } + } + From 4dccbce9dfb174c622d7a6cd95a7aba252e20e9f Mon Sep 17 00:00:00 2001 From: Andrey Zvonov Date: Fri, 15 May 2026 21:29:11 +0200 Subject: [PATCH 4/4] Resolve conflicts in cherry-pick of #101272 --- src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp index 7a41c4901969..ec0aada93897 100644 --- a/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp +++ b/src/Processors/Formats/Impl/ArrowColumnToCHColumn.cpp @@ -1425,10 +1425,6 @@ static ColumnWithTypeAndName readColumnFromArrowColumn( arrow_column->type()->id() != arrow::Type::LARGE_LIST && arrow_column->type()->id() != arrow::Type::FIXED_SIZE_LIST && arrow_column->type()->id() != arrow::Type::MAP && -<<<<<<< HEAD - arrow_column->type()->id() != arrow::Type::STRUCT && -======= ->>>>>>> fc17de3cb80 (Merge 
pull request #101272 from nihalzp/support-arrow-orc-nullable-tuple) arrow_column->type()->id() != arrow::Type::DICTIONARY) { DataTypePtr nested_type_hint;