diff --git a/native/Cargo.lock b/native/Cargo.lock index 465454adc5..af1a096845 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1837,6 +1837,7 @@ dependencies = [ "crc32fast", "criterion", "datafusion", + "datafusion-comet-jni-bridge", "datafusion-comet-objectstore-hdfs", "datafusion-comet-proto", "datafusion-comet-spark-expr", @@ -1869,13 +1870,11 @@ dependencies = [ "procfs", "prost", "rand 0.10.0", - "regex", "reqwest", "serde_json", "simd-adler32", "snap", "tempfile", - "thiserror 2.0.18", "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", @@ -1884,6 +1883,17 @@ dependencies = [ "zstd", ] +[[package]] +name = "datafusion-comet-common" +version = "0.14.0" +dependencies = [ + "arrow", + "datafusion", + "serde", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "datafusion-comet-fs-hdfs3" version = "0.1.12" @@ -1899,6 +1909,24 @@ dependencies = [ "uuid", ] +[[package]] +name = "datafusion-comet-jni-bridge" +version = "0.14.0" +dependencies = [ + "arrow", + "assertables", + "datafusion", + "datafusion-comet-common", + "jni", + "lazy_static", + "once_cell", + "parquet", + "paste", + "prost", + "regex", + "thiserror 2.0.18", +] + [[package]] name = "datafusion-comet-objectstore-hdfs" version = "0.14.0" @@ -1931,14 +1959,13 @@ dependencies = [ "chrono-tz", "criterion", "datafusion", + "datafusion-comet-common", "futures", "hex", "num", "rand 0.10.0", "regex", - "serde", "serde_json", - "thiserror 2.0.18", "tokio", "twox-hash", ] diff --git a/native/Cargo.toml b/native/Cargo.toml index 7979978a31..fe9d524ecc 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[workspace] -default-members = ["core", "spark-expr", "proto"] -members = ["core", "spark-expr", "proto", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "common", "proto", "jni-bridge"] +members = ["core", "spark-expr", "common", "proto", "jni-bridge", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -43,6 +43,8 @@ datafusion-datasource = { version = "52.2.0" } datafusion-physical-expr-adapter = { version = "52.2.0" } datafusion-spark = { version = "52.2.0" } datafusion-comet-spark-expr = { path = "spark-expr" } +datafusion-comet-common = { path = "common" } +datafusion-comet-jni-bridge = { path = "jni-bridge" } datafusion-comet-proto = { path = "proto" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.10" } diff --git a/native/common/Cargo.toml b/native/common/Cargo.toml new file mode 100644 index 0000000000..3bbc44856e --- /dev/null +++ b/native/common/Cargo.toml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-comet-common" +description = "Apache DataFusion Comet: common types shared across crates" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +publish = false + +[dependencies] +arrow = { workspace = true } +datafusion = { workspace = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = { workspace = true } + +[lib] +name = "datafusion_comet_common" +path = "src/lib.rs" diff --git a/native/common/src/error.rs b/native/common/src/error.rs new file mode 100644 index 0000000000..e36f069ac2 --- /dev/null +++ b/native/common/src/error.rs @@ -0,0 +1,842 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::error::ArrowError; +use datafusion::common::DataFusionError; +use std::sync::Arc; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SparkError { + // This list was generated from the Spark code. 
Many of the exceptions are not yet used by Comet + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] + NumericOutOfRange { value: String }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] + CannotParseDecimal, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntegralDivideOverflow, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DecimalSumOverflow, + + #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DivideByZero, + + #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + RemainderByZero, + + #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] + IntervalDividedByZero, + + #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + BinaryArithmeticOverflow { + value1: String, + symbol: String, + value2: String, + function_name: String, + }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithSuggestion { function_name: String }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithoutSuggestion, + + #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] + DatetimeOverflow, + + #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidArrayIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidElementAtIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] + InvalidBitmapPosition { + bit_position: i64, + bitmap_num_bytes: i64, + bitmap_num_bits: i64, + }, + + #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] + InvalidIndexOfZero, + + #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] + DuplicatedMapKey { key: String }, + + #[error("[NULL_MAP_KEY] Cannot use null as map key.")] + NullMapKey, + + #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] + MapKeyValueDiffSizes, + + #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] + ExceedMapSizeLimit { size: i32, max_size: i32 }, + + #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] + CollectionSizeLimitExceeded { + num_elements: i64, + max_elements: i64, + }, + + #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] + NotNullAssertViolation { field_name: String }, + + #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] + ValueIsNull { field_name: String, row_index: i32 }, + + #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. 
Try using `{suggested_func}` instead.")] + CannotParseTimestamp { + message: String, + suggested_func: String, + }, + + #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidFractionOfSecond { value: f64 }, + + #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] + InvalidUtf8String { hex_string: String }, + + #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] + UnexpectedPositiveValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] + UnexpectedNegativeValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. 
Group count is {group_count}.")] + InvalidRegexGroupIndex { + function_name: String, + group_count: i32, + group_index: i32, + }, + + #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] + DatatypeCannotOrder { data_type: String }, + + #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] + ScalarSubqueryTooManyRows, + + #[error("ArrowError: {0}.")] + Arrow(Arc), + + #[error("InternalError: {0}.")] + Internal(String), +} + +impl SparkError { + /// Serialize this error to JSON format for JNI transfer + pub fn to_json(&self) -> String { + let error_class = self.error_class().unwrap_or(""); + + // Create a JSON structure with errorType, errorClass, and params + match serde_json::to_string(&serde_json::json!({ + "errorType": self.error_type_name(), + "errorClass": error_class, + "params": self.params_as_json(), + })) { + Ok(json) => json, + Err(e) => { + // Fallback if serialization fails + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + } + } + } + + /// Get the error type name for JSON serialization + pub(crate) fn error_type_name(&self) -> &'static str { + match self { + SparkError::CastInvalidValue { .. } => "CastInvalidValue", + SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", + SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", + SparkError::CastOverFlow { .. } => "CastOverFlow", + SparkError::CannotParseDecimal => "CannotParseDecimal", + SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", + SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", + SparkError::DecimalSumOverflow => "DecimalSumOverflow", + SparkError::DivideByZero => "DivideByZero", + SparkError::RemainderByZero => "RemainderByZero", + SparkError::IntervalDividedByZero => "IntervalDividedByZero", + SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", + SparkError::IntervalArithmeticOverflowWithSuggestion { .. 
} => { + "IntervalArithmeticOverflowWithSuggestion" + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + "IntervalArithmeticOverflowWithoutSuggestion" + } + SparkError::DatetimeOverflow => "DatetimeOverflow", + SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", + SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", + SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", + SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", + SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", + SparkError::NullMapKey => "NullMapKey", + SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", + SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", + SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", + SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", + SparkError::ValueIsNull { .. } => "ValueIsNull", + SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", + SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", + SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", + SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", + SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", + SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", + SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", + SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", + SparkError::Arrow(_) => "Arrow", + SparkError::Internal(_) => "Internal", + } + } + + /// Extract parameters as JSON value + pub(crate) fn params_as_json(&self) -> serde_json::Value { + match self { + SparkError::CastInvalidValue { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::NumericValueOutOfRange { + value, + precision, + scale, + } => { + serde_json::json!({ + "value": value, + "precision": precision, + "scale": scale, + }) + } + SparkError::NumericOutOfRange { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::CastOverFlow { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::ArithmeticOverflow { from_type } => { + serde_json::json!({ + "fromType": from_type, + }) + } + SparkError::BinaryArithmeticOverflow { + value1, + symbol, + value2, + function_name, + } => { + serde_json::json!({ + "value1": value1, + "symbol": symbol, + "value2": value2, + "functionName": function_name, + }) + } + SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { + serde_json::json!({ + "functionName": function_name, + }) + } + SparkError::InvalidArrayIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidElementAtIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidBitmapPosition { + bit_position, + bitmap_num_bytes, + bitmap_num_bits, + } => { + serde_json::json!({ + "bitPosition": bit_position, + "bitmapNumBytes": bitmap_num_bytes, + "bitmapNumBits": bitmap_num_bits, + }) + } + SparkError::DuplicatedMapKey { key } => { + 
serde_json::json!({ + "key": key, + }) + } + SparkError::ExceedMapSizeLimit { size, max_size } => { + serde_json::json!({ + "size": size, + "maxSize": max_size, + }) + } + SparkError::CollectionSizeLimitExceeded { + num_elements, + max_elements, + } => { + serde_json::json!({ + "numElements": num_elements, + "maxElements": max_elements, + }) + } + SparkError::NotNullAssertViolation { field_name } => { + serde_json::json!({ + "fieldName": field_name, + }) + } + SparkError::ValueIsNull { + field_name, + row_index, + } => { + serde_json::json!({ + "fieldName": field_name, + "rowIndex": row_index, + }) + } + SparkError::CannotParseTimestamp { + message, + suggested_func, + } => { + serde_json::json!({ + "message": message, + "suggestedFunc": suggested_func, + }) + } + SparkError::InvalidFractionOfSecond { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::InvalidUtf8String { hex_string } => { + serde_json::json!({ + "hexString": hex_string, + }) + } + SparkError::UnexpectedPositiveValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::UnexpectedNegativeValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::InvalidRegexGroupIndex { + function_name, + group_count, + group_index, + } => { + serde_json::json!({ + "functionName": function_name, + "groupCount": group_count, + "groupIndex": group_index, + }) + } + SparkError::DatatypeCannotOrder { data_type } => { + serde_json::json!({ + "dataType": data_type, + }) + } + SparkError::Arrow(e) => { + serde_json::json!({ + "message": e.to_string(), + }) + } + SparkError::Internal(msg) => { + serde_json::json!({ + "message": msg, + }) + } + // Simple errors with no parameters + _ => serde_json::json!({}), + } + } + + /// Returns the appropriate Spark exception class for this error + pub fn 
exception_class(&self) -> &'static str { + match self { + // ArithmeticException + SparkError::DivideByZero + | SparkError::RemainderByZero + | SparkError::IntervalDividedByZero + | SparkError::NumericValueOutOfRange { .. } + | SparkError::NumericOutOfRange { .. } // Comet-specific extension + | SparkError::ArithmeticOverflow { .. } + | SparkError::IntegralDivideOverflow + | SparkError::DecimalSumOverflow + | SparkError::BinaryArithmeticOverflow { .. } + | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } + | SparkError::IntervalArithmeticOverflowWithoutSuggestion + | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", + + // CastOverflow gets special handling with CastOverflowException + SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", + + // NumberFormatException (for cast invalid input errors) + SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", + + // ArrayIndexOutOfBoundsException + SparkError::InvalidArrayIndex { .. } + | SparkError::InvalidElementAtIndex { .. } + | SparkError::InvalidBitmapPosition { .. } + | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", + + // RuntimeException + SparkError::CannotParseDecimal + | SparkError::DuplicatedMapKey { .. } + | SparkError::NullMapKey + | SparkError::MapKeyValueDiffSizes + | SparkError::ExceedMapSizeLimit { .. } + | SparkError::CollectionSizeLimitExceeded { .. } + | SparkError::NotNullAssertViolation { .. } + | SparkError::ValueIsNull { .. } // Comet-specific extension + | SparkError::UnexpectedPositiveValue { .. } + | SparkError::UnexpectedNegativeValue { .. } + | SparkError::InvalidRegexGroupIndex { .. } + | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", + + // DateTimeException + SparkError::CannotParseTimestamp { .. } + | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", + + // IllegalArgumentException + SparkError::DatatypeCannotOrder { .. } + | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", + + // Generic errors + SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", + } + } + + /// Returns the Spark error class code for this error + pub(crate) fn error_class(&self) -> Option<&'static str> { + match self { + // Cast errors + SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), + SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), + SparkError::NumericValueOutOfRange { .. } => { + Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") + } + SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), + SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), + + // Arithmetic errors + SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), + SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), + SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), + SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), + SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), + + // Array index errors + SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), + SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), + SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), + + // Map/Collection errors + SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), + SparkError::NullMapKey => Some("NULL_MAP_KEY"), + SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), + SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), + SparkError::CollectionSizeLimitExceeded { .. } => { + Some("COLLECTION_SIZE_LIMIT_EXCEEDED") + } + + // Null validation errors + SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), + SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), + + // DateTime errors + SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), + SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), + + // String/UTF8 errors + SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), + + // Function parameter errors + SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), + SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), + + // Regex errors + SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), + + // Unsupported operation errors + SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), + + // Subquery errors + SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), + + // Generic errors (no error class) + SparkError::Arrow(_) | SparkError::Internal(_) => None, + } + } +} + +pub type SparkResult = Result; + +/// Convert decimal overflow to SparkError::NumericValueOutOfRange. 
+pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { + SparkError::NumericValueOutOfRange { + value: value.to_string(), + precision, + scale, + } +} + +/// Wrapper that adds QueryContext to SparkError +/// +/// This allows attaching SQL context information (query text, line/position, object name) to errors +#[derive(Debug, Clone)] +pub struct SparkErrorWithContext { + /// The underlying SparkError + pub error: SparkError, + /// Optional QueryContext for SQL location information + pub context: Option>, +} + +impl SparkErrorWithContext { + /// Create a SparkErrorWithContext without context + pub fn new(error: SparkError) -> Self { + Self { + error, + context: None, + } + } + + /// Create a SparkErrorWithContext with QueryContext + pub fn with_context(error: SparkError, context: Arc) -> Self { + Self { + error, + context: Some(context), + } + } + + /// Serialize to JSON including optional context field + pub fn to_json(&self) -> String { + let mut json_obj = serde_json::json!({ + "errorType": self.error.error_type_name(), + "errorClass": self.error.error_class().unwrap_or(""), + "params": self.error.params_as_json(), + }); + + if let Some(ctx) = &self.context { + // Serialize context fields + json_obj["context"] = serde_json::json!({ + "sqlText": ctx.sql_text.as_str(), + "startIndex": ctx.start_index, + "stopIndex": ctx.stop_index, + "objectType": ctx.object_type, + "objectName": ctx.object_name, + "line": ctx.line, + "startPosition": ctx.start_position, + }); + + // Add formatted summary + json_obj["summary"] = serde_json::json!(ctx.format_summary()); + } + + serde_json::to_string(&json_obj).unwrap_or_else(|e| { + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + }) + } +} + +impl std::fmt::Display for SparkErrorWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error)?; + if let Some(ctx) = &self.context { + write!(f, "\n{}", 
ctx.format_summary())?; + } + Ok(()) + } +} + +impl std::error::Error for SparkErrorWithContext {} + +impl From for SparkErrorWithContext { + fn from(error: SparkError) -> Self { + SparkErrorWithContext::new(error) + } +} + +impl From for DataFusionError { + fn from(value: SparkErrorWithContext) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +impl From for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(Arc::new(value)) + } +} + +impl From for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_divide_by_zero_json() { + let error = SparkError::DivideByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"DivideByZero\"")); + assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DivideByZero"); + assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); + } + + #[test] + fn test_remainder_by_zero_json() { + let error = SparkError::RemainderByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"RemainderByZero\"")); + assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); + } + + #[test] + fn test_binary_overflow_json() { + let error = SparkError::BinaryArithmeticOverflow { + value1: "32767".to_string(), + symbol: "+".to_string(), + value2: "1".to_string(), + function_name: "try_add".to_string(), + }; + let json = error.to_json(); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); + assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); + assert_eq!(parsed["params"]["value1"], "32767"); + assert_eq!(parsed["params"]["symbol"], "+"); + assert_eq!(parsed["params"]["value2"], "1"); + 
assert_eq!(parsed["params"]["functionName"], "try_add"); + } + + #[test] + fn test_invalid_array_index_json() { + let error = SparkError::InvalidArrayIndex { + index_value: 10, + array_size: 3, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "InvalidArrayIndex"); + assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); + assert_eq!(parsed["params"]["indexValue"], 10); + assert_eq!(parsed["params"]["arraySize"], 3); + } + + #[test] + fn test_numeric_value_out_of_range_json() { + let error = SparkError::NumericValueOutOfRange { + value: "999.99".to_string(), + precision: 5, + scale: 2, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); + assert_eq!( + parsed["errorClass"], + "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" + ); + assert_eq!(parsed["params"]["value"], "999.99"); + assert_eq!(parsed["params"]["precision"], 5); + assert_eq!(parsed["params"]["scale"], 2); + } + + #[test] + fn test_cast_invalid_value_json() { + let error = SparkError::CastInvalidValue { + value: "abc".to_string(), + from_type: "STRING".to_string(), + to_type: "INT".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "CastInvalidValue"); + assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); + assert_eq!(parsed["params"]["value"], "abc"); + assert_eq!(parsed["params"]["fromType"], "STRING"); + assert_eq!(parsed["params"]["toType"], "INT"); + } + + #[test] + fn test_duplicated_map_key_json() { + let error = SparkError::DuplicatedMapKey { + key: "duplicate_key".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DuplicatedMapKey"); + assert_eq!(parsed["errorClass"], 
"DUPLICATED_MAP_KEY"); + assert_eq!(parsed["params"]["key"], "duplicate_key"); + } + + #[test] + fn test_null_map_key_json() { + let error = SparkError::NullMapKey; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NullMapKey"); + assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); + // Params should be an empty object + assert_eq!(parsed["params"], serde_json::json!({})); + } + + #[test] + fn test_error_class_mapping() { + // Test that error_class() returns the correct error class + assert_eq!( + SparkError::DivideByZero.error_class(), + Some("DIVIDE_BY_ZERO") + ); + assert_eq!( + SparkError::RemainderByZero.error_class(), + Some("REMAINDER_BY_ZERO") + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .error_class(), + Some("INVALID_ARRAY_INDEX") + ); + assert_eq!(SparkError::NullMapKey.error_class(), Some("NULL_MAP_KEY")); + } + + #[test] + fn test_exception_class_mapping() { + // Test that exception_class() returns the correct Java exception class + assert_eq!( + SparkError::DivideByZero.exception_class(), + "org/apache/spark/SparkArithmeticException" + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .exception_class(), + "org/apache/spark/SparkArrayIndexOutOfBoundsException" + ); + assert_eq!( + SparkError::NullMapKey.exception_class(), + "org/apache/spark/SparkRuntimeException" + ); + } +} diff --git a/native/common/src/lib.rs b/native/common/src/lib.rs new file mode 100644 index 0000000000..9319d7347f --- /dev/null +++ b/native/common/src/lib.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod error; +mod query_context; + +pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; +pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; diff --git a/native/common/src/query_context.rs b/native/common/src/query_context.rs new file mode 100644 index 0000000000..10ecef6550 --- /dev/null +++ b/native/common/src/query_context.rs @@ -0,0 +1,403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Query execution context for error reporting +//! +//! This module provides QueryContext which mirrors Spark's SQLQueryContext +//! 
for providing SQL text, line/position information, and error location +//! pointers in exception messages. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Based on Spark's SQLQueryContext for error reporting. +/// +/// Contains information about where an error occurred in a SQL query, +/// including the full SQL text, line/column positions, and object context. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct QueryContext { + /// Full SQL query text + #[serde(rename = "sqlText")] + pub sql_text: Arc, + + /// Start offset in SQL text (0-based, character index) + #[serde(rename = "startIndex")] + pub start_index: i32, + + /// Stop offset in SQL text (0-based, character index, inclusive) + #[serde(rename = "stopIndex")] + pub stop_index: i32, + + /// Object type (e.g., "VIEW", "Project", "Filter") + #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] + pub object_type: Option, + + /// Object name (e.g., view name, column name) + #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] + pub object_name: Option, + + /// Line number in SQL query (1-based) + pub line: i32, + + /// Column position within the line (0-based) + #[serde(rename = "startPosition")] + pub start_position: i32, +} + +impl QueryContext { + #[allow(clippy::too_many_arguments)] + pub fn new( + sql_text: String, + start_index: i32, + stop_index: i32, + object_type: Option, + object_name: Option, + line: i32, + start_position: i32, + ) -> Self { + Self { + sql_text: Arc::new(sql_text), + start_index, + stop_index, + object_type, + object_name, + line, + start_position, + } + } + + /// Convert a character index to a byte offset in the SQL text. + /// Returns None if the character index is out of range. 
+ fn char_index_to_byte_offset(&self, char_index: usize) -> Option { + self.sql_text + .char_indices() + .nth(char_index) + .map(|(byte_offset, _)| byte_offset) + } + + /// Generate a summary string showing SQL fragment with error location. + /// (From SQLQueryContext.summary) + /// + /// Format example: + /// ```text + /// == SQL of VIEW v1 (line 1, position 8) == + /// SELECT a/b FROM t + /// ^^^ + /// ``` + pub fn format_summary(&self) -> String { + let start_char = self.start_index.max(0) as usize; + // stop_index is inclusive; fragment covers [start, stop] + let stop_char = (self.stop_index + 1).max(0) as usize; + + let fragment = match ( + self.char_index_to_byte_offset(start_char), + // stop_char may equal sql_text.chars().count() (one past the end) + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], + _ => "", + }; + + // Build the header line + let mut summary = String::from("== SQL"); + + if let Some(obj_type) = &self.object_type { + if !obj_type.is_empty() { + summary.push_str(" of "); + summary.push_str(obj_type); + + if let Some(obj_name) = &self.object_name { + if !obj_name.is_empty() { + summary.push(' '); + summary.push_str(obj_name); + } + } + } + } + + summary.push_str(&format!( + " (line {}, position {}) ==\n", + self.line, + self.start_position + 1 // Convert 0-based to 1-based for display + )); + + // Add the SQL text with fragment highlighted + summary.push_str(&self.sql_text); + summary.push('\n'); + + // Add caret pointer + let caret_position = self.start_position.max(0) as usize; + summary.push_str(&" ".repeat(caret_position)); + // fragment.chars().count() gives the correct display width for non-ASCII + summary.push_str(&"^".repeat(fragment.chars().count().max(1))); + + summary + } + + /// Returns the SQL fragment that caused the error. 
+ #[cfg(test)] + fn fragment(&self) -> String { + let start_char = self.start_index.max(0) as usize; + let stop_char = (self.stop_index + 1).max(0) as usize; + + match ( + self.char_index_to_byte_offset(start_char), + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), + _ => String::new(), + } + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +/// Map that stores QueryContext information for expressions during execution. +/// +/// This map is populated during plan deserialization and accessed +/// during error creation to attach SQL context to exceptions. +#[derive(Debug)] +pub struct QueryContextMap { + /// Map from expression ID to QueryContext + contexts: RwLock>>, +} + +impl QueryContextMap { + pub fn new() -> Self { + Self { + contexts: RwLock::new(HashMap::new()), + } + } + + /// Register a QueryContext for an expression ID. + /// + /// If the expression ID already exists, it will be replaced. + /// + /// # Arguments + /// * `expr_id` - Unique expression identifier from protobuf + /// * `context` - QueryContext containing SQL text and position info + pub fn register(&self, expr_id: u64, context: QueryContext) { + let mut contexts = self.contexts.write().unwrap(); + contexts.insert(expr_id, Arc::new(context)); + } + + /// Get the QueryContext for an expression ID. + /// + /// Returns None if no context is registered for this expression. + /// + /// # Arguments + /// * `expr_id` - Expression identifier to look up + pub fn get(&self, expr_id: u64) -> Option> { + let contexts = self.contexts.read().unwrap(); + contexts.get(&expr_id).cloned() + } + + /// Clear all registered contexts. + /// + /// This is typically called after plan execution completes to free memory. 
+ pub fn clear(&self) { + let mut contexts = self.contexts.write().unwrap(); + contexts.clear(); + } + + /// Return the number of registered contexts (for debugging/testing) + pub fn len(&self) -> usize { + let contexts = self.contexts.read().unwrap(); + contexts.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for QueryContextMap { + fn default() -> Self { + Self::new() + } +} + +/// Create a new session-scoped QueryContextMap. +/// +/// This should be called once per SessionContext during plan creation +/// and passed to expressions that need query context for error reporting. +pub fn create_query_context_map() -> Arc { + Arc::new(QueryContextMap::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_context_creation() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); + assert_eq!(ctx.start_index, 7); + assert_eq!(ctx.stop_index, 9); + assert_eq!(ctx.object_type, Some("Divide".to_string())); + assert_eq!(ctx.object_name, Some("a/b".to_string())); + assert_eq!(ctx.line, 1); + assert_eq!(ctx.start_position, 7); + } + + #[test] + fn test_query_context_serialization() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + let json = serde_json::to_string(&ctx).unwrap(); + let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); + + assert_eq!(ctx, deserialized); + } + + #[test] + fn test_format_summary() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("VIEW".to_string()), + Some("v1".to_string()), + 1, + 7, + ); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM 
t")); + assert!(summary.contains("^^^")); // Three carets for "a/b" + } + + #[test] + fn test_format_summary_without_object() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + } + + #[test] + fn test_fragment() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + assert_eq!(ctx.fragment(), "a/b"); + } + + #[test] + fn test_arc_string_sharing() { + let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let ctx2 = ctx1.clone(); + + // Arc should share the same allocation + assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); + } + + #[test] + fn test_json_with_optional_fields() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let json = serde_json::to_string(&ctx).unwrap(); + + // Should not serialize objectType and objectName when None + assert!(!json.contains("objectType")); + assert!(!json.contains("objectName")); + } + + #[test] + fn test_map_register_and_get() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx.clone()); + + let retrieved = map.get(1).unwrap(); + assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); + assert_eq!(retrieved.start_index, 7); + } + + #[test] + fn test_map_get_nonexistent() { + let map = QueryContextMap::new(); + assert!(map.get(999).is_none()); + } + + #[test] + fn test_map_clear() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx); + assert_eq!(map.len(), 1); + + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + } + + // Verify that fragment() and format_summary() correctly handle SQL text that + // 
contains multi-byte characters + + #[test] + fn test_fragment_non_ascii_accented() { + // "é" is a 2-byte UTF-8 sequence (U+00E9). + // SQL: "SELECT café FROM t" + // 0123456789... + // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. + // start_index=7, stop_index=10 should yield "café" + let sql = "SELECT café FROM t".to_string(); + let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); + assert_eq!(ctx.fragment(), "café"); + } +} diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 9c4ec9775c..db8a75849e 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -45,8 +45,6 @@ tokio = { version = "1", features = ["rt-multi-thread"] } async-trait = { workspace = true } log = "0.4" log4rs = "1.4.0" -thiserror = { workspace = true } -lazy_static = "1.4.0" prost = "0.14.3" jni = "0.21" snap = "1.1" @@ -64,10 +62,10 @@ datafusion-physical-expr-adapter = { workspace = true } datafusion-datasource = { workspace = true } datafusion-spark = { workspace = true } once_cell = "1.18.0" -regex = { workspace = true } crc32fast = "1.3.2" simd-adler32 = "0.3.7" datafusion-comet-spark-expr = { workspace = true } +datafusion-comet-jni-bridge = { workspace = true } datafusion-comet-proto = { workspace = true } object_store = { workspace = true } url = { workspace = true } @@ -108,7 +106,7 @@ jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"] # exclude optional packages from cargo machete verifications [package.metadata.cargo-machete] -ignored = ["datafusion-comet-objectstore-hdfs", "hdfs-sys"] +ignored = ["hdfs-sys", "paste"] [lib] name = "comet" diff --git a/native/core/src/execution/expressions/subquery.rs b/native/core/src/execution/expressions/subquery.rs index 52f9d13f12..ad4106c251 100644 --- a/native/core/src/execution/expressions/subquery.rs +++ b/native/core/src/execution/expressions/subquery.rs @@ -17,7 +17,7 @@ use crate::{ execution::utils::bytes_to_i128, - jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, + 
jvm_bridge::{BinaryWrapper, JVMClasses, StringWrapper}, }; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Schema, TimeUnit}; diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..59ac674431 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -24,7 +24,7 @@ use crate::{ metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype, shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort, }, - jvm_bridge::{jni_new_global_ref, JVMClasses}, + jvm_bridge::JVMClasses, }; use arrow::array::{Array, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; diff --git a/native/core/src/execution/memory_pools/fair_pool.rs b/native/core/src/execution/memory_pools/fair_pool.rs index 1a98f91e49..2c25fe9443 100644 --- a/native/core/src/execution/memory_pools/fair_pool.rs +++ b/native/core/src/execution/memory_pools/fair_pool.rs @@ -22,10 +22,7 @@ use std::{ use jni::objects::GlobalRef; -use crate::{ - errors::CometResult, - jvm_bridge::{jni_call, JVMClasses}, -}; +use crate::{errors::CometResult, jvm_bridge::JVMClasses}; use datafusion::common::resources_err; use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::{ diff --git a/native/core/src/execution/memory_pools/unified_pool.rs b/native/core/src/execution/memory_pools/unified_pool.rs index 88b2731072..3233dd6d40 100644 --- a/native/core/src/execution/memory_pools/unified_pool.rs +++ b/native/core/src/execution/memory_pools/unified_pool.rs @@ -23,10 +23,7 @@ use std::{ }, }; -use crate::{ - errors::CometResult, - jvm_bridge::{jni_call, JVMClasses}, -}; +use crate::{errors::CometResult, jvm_bridge::JVMClasses}; use datafusion::{ common::{resources_datafusion_err, DataFusionError}, execution::memory_pool::{MemoryPool, MemoryReservation}, diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 
9ec35ad951..161c1f1cf9 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::errors::CometError; use crate::execution::spark_plan::SparkPlan; -use crate::{errors::CometError, jvm_bridge::jni_call}; use datafusion::physical_plan::metrics::MetricValue; use datafusion_comet_proto::spark_metric::NativeMetricNode; use jni::{objects::JObject, JNIEnv}; diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..3c3814a2b5 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -17,9 +17,7 @@ //! Operators -use std::fmt::Debug; - -use jni::objects::GlobalRef; +pub use crate::errors::ExecutionError; pub use copy::*; pub use iceberg_scan::*; @@ -35,31 +33,3 @@ mod csv_scan; pub mod projection; mod scan; pub use csv_scan::init_csv_datasource_exec; - -/// Error returned during executing operators. -#[derive(thiserror::Error, Debug)] -pub enum ExecutionError { - /// Simple error - #[allow(dead_code)] - #[error("General execution error with reason: {0}.")] - GeneralError(String), - - /// Error when deserializing an operator. - #[error("Fail to deserialize to native operator with reason: {0}.")] - DeserializeError(String), - - /// Error when processing Arrow array. 
- #[error("Fail to process Arrow array with reason: {0}.")] - ArrowError(String), - - /// DataFusion error - #[error("Error from DataFusion: {0}.")] - DataFusionError(String), - - #[error("{class}: {msg}")] - JavaException { - class: String, - msg: String, - throwable: GlobalRef, - }, -} diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 2543705fb0..2394912e41 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -21,7 +21,7 @@ use crate::{ execution::{ operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, utils::SparkArrowConvert, }, - jvm_bridge::{jni_call, JVMClasses}, + jvm_bridge::JVMClasses, }; use arrow::array::{make_array, ArrayData, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, take, CastOptions}; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b79b43f6c9..c9b9d4ab63 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -23,16 +23,13 @@ pub mod operator_registry; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; -use crate::{ - errors::ExpressionError, - execution::{ - expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, - planner::expression_registry::ExpressionRegistry, - planner::operator_registry::OperatorRegistry, - serde::to_arrow_datatype, - shuffle::ShuffleWriterExec, - }, +use crate::execution::{ + expressions::subquery::Subquery, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + planner::expression_registry::ExpressionRegistry, + planner::operator_registry::OperatorRegistry, + serde::to_arrow_datatype, + shuffle::ShuffleWriterExec, }; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; @@ -2626,24 
+2623,6 @@ impl PhysicalPlanner { } } -impl From for ExecutionError { - fn from(value: DataFusionError) -> Self { - ExecutionError::DataFusionError(value.message().to_string()) - } -} - -impl From for DataFusionError { - fn from(value: ExecutionError) -> Self { - DataFusionError::Execution(value.to_string()) - } -} - -impl From for DataFusionError { - fn from(value: ExpressionError) -> Self { - DataFusionError::Execution(value.to_string()) - } -} - /// Collects the indices of the columns in the input schema that are used in the expression /// and returns them as a pair of vectors, one for the left side and one for the right side. fn expr_to_columns( diff --git a/native/core/src/execution/serde.rs b/native/core/src/execution/serde.rs index e95fd7eca2..ae0554ee76 100644 --- a/native/core/src/execution/serde.rs +++ b/native/core/src/execution/serde.rs @@ -34,30 +34,6 @@ use datafusion_comet_proto::{ use prost::Message; use std::{io::Cursor, sync::Arc}; -impl From for ExpressionError { - fn from(error: prost::DecodeError) -> ExpressionError { - ExpressionError::Deserialize(error.to_string()) - } -} - -impl From for ExpressionError { - fn from(error: prost::UnknownEnumValue) -> ExpressionError { - ExpressionError::Deserialize(error.to_string()) - } -} - -impl From for ExecutionError { - fn from(error: prost::DecodeError) -> ExecutionError { - ExecutionError::DeserializeError(error.to_string()) - } -} - -impl From for ExecutionError { - fn from(error: prost::UnknownEnumValue) -> ExecutionError { - ExecutionError::DeserializeError(error.to_string()) - } -} - /// Deserialize bytes to protobuf type of expression pub fn deserialize_expr(buf: &[u8]) -> Result { match spark_expression::Expr::decode(&mut Cursor::new(buf)) { diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 9e6f2a56e7..f95423aa70 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -16,32 +16,12 @@ // under the License. 
/// Utils for array vector, etc. -use crate::errors::ExpressionError; use crate::execution::operators::ExecutionError; use arrow::{ array::ArrayData, - error::ArrowError, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; -impl From for ExecutionError { - fn from(error: ArrowError) -> ExecutionError { - ExecutionError::ArrowError(error.to_string()) - } -} - -impl From for ExpressionError { - fn from(error: ArrowError) -> ExpressionError { - ExpressionError::ArrowError(error.to_string()) - } -} - -impl From for ArrowError { - fn from(error: ExpressionError) -> ArrowError { - ArrowError::ComputeError(error.to_string()) - } -} - pub trait SparkArrowConvert { /// Build Arrow Arrays from C data interface passed from Spark. /// It accepts a tuple (ArrowArray address, ArrowSchema address). diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 2b883bd7df..1b87dc1dba 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -26,9 +26,12 @@ #![deny(clippy::clone_on_ref_ptr)] extern crate core; +#[macro_use] +extern crate datafusion_comet_jni_bridge; + use jni::{ objects::{JClass, JString}, - JNIEnv, JavaVM, + JNIEnv, }; use log::info; use log4rs::{ @@ -37,7 +40,6 @@ use log4rs::{ encode::pattern::PatternEncoder, Config, }; -use once_cell::sync::OnceCell; #[cfg(all( not(target_env = "msvc"), @@ -52,14 +54,20 @@ use tikv_jemallocator::Jemalloc; ))] use mimalloc::MiMalloc; +// Re-export from jvm-bridge crate for internal use +pub use datafusion_comet_jni_bridge::errors; +pub use datafusion_comet_jni_bridge::JAVA_VM; + +/// Re-export jvm-bridge items under the `jvm_bridge` name for convenience. 
+pub mod jvm_bridge { + pub use datafusion_comet_jni_bridge::*; +} + use errors::{try_unwrap_or_throw, CometError, CometResult}; -#[macro_use] -mod errors; #[macro_use] pub mod common; pub mod execution; -mod jvm_bridge; pub mod parquet; #[cfg(all( @@ -77,8 +85,6 @@ static GLOBAL: Jemalloc = Jemalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -static JAVA_VM: OnceCell = OnceCell::new(); - #[no_mangle] pub extern "system" fn Java_org_apache_comet_NativeBase_init( e: JNIEnv, diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index f2b0e80ab2..f386b850d4 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -53,7 +53,7 @@ use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; use crate::execution::spark_plan::SparkPlan; use crate::execution::utils::SparkArrowConvert; -use crate::jvm_bridge::{jni_new_global_ref, JVMClasses}; +use crate::jvm_bridge::JVMClasses; use crate::parquet::data_type::AsBytes; use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use crate::parquet::parquet_exec::init_datasource_exec; diff --git a/native/jni-bridge/Cargo.toml b/native/jni-bridge/Cargo.toml new file mode 100644 index 0000000000..0c50825667 --- /dev/null +++ b/native/jni-bridge/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-comet-jni-bridge" +version = { workspace = true } +homepage = "https://datafusion.apache.org/comet" +repository = "https://github.com/apache/datafusion-comet" +authors = ["Apache DataFusion "] +description = "Apache DataFusion Comet: JNI bridge" +readme = "README.md" +license = "Apache-2.0" +edition = "2021" + +publish = false + +[dependencies] +arrow = { workspace = true } +parquet = { workspace = true } +datafusion = { workspace = true } +jni = "0.21" +thiserror = { workspace = true } +regex = { workspace = true } +lazy_static = "1.4.0" +once_cell = "1.18.0" +paste = "1.0.14" +prost = "0.14.3" +datafusion-comet-common = { workspace = true } + +[dev-dependencies] +jni = { version = "0.21", features = ["invocation"] } +assertables = "9" diff --git a/native/core/src/jvm_bridge/batch_iterator.rs b/native/jni-bridge/src/batch_iterator.rs similarity index 100% rename from native/core/src/jvm_bridge/batch_iterator.rs rename to native/jni-bridge/src/batch_iterator.rs diff --git a/native/core/src/jvm_bridge/comet_exec.rs b/native/jni-bridge/src/comet_exec.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_exec.rs rename to native/jni-bridge/src/comet_exec.rs diff --git a/native/core/src/jvm_bridge/comet_metric_node.rs b/native/jni-bridge/src/comet_metric_node.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_metric_node.rs rename to native/jni-bridge/src/comet_metric_node.rs diff --git a/native/core/src/jvm_bridge/comet_task_memory_manager.rs 
b/native/jni-bridge/src/comet_task_memory_manager.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_task_memory_manager.rs rename to native/jni-bridge/src/comet_task_memory_manager.rs diff --git a/native/core/src/errors.rs b/native/jni-bridge/src/errors.rs similarity index 90% rename from native/core/src/errors.rs rename to native/jni-bridge/src/errors.rs index 7c8957dba7..aff471e245 100644 --- a/native/core/src/errors.rs +++ b/native/jni-bridge/src/errors.rs @@ -19,6 +19,7 @@ use arrow::error::ArrowError; use datafusion::common::DataFusionError; +use datafusion_comet_common::{SparkError, SparkErrorWithContext}; use jni::errors::{Exception, ToException}; use regex::Regex; @@ -37,8 +38,6 @@ use std::{ // lifetime checker won't let us. use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; -use crate::execution::operators::ExecutionError; -use datafusion_comet_spark_expr::SparkError; use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; @@ -49,6 +48,34 @@ lazy_static! { static ref PANIC_BACKTRACE: Arc>> = Arc::new(Mutex::new(None)); } +/// Error returned during executing operators. +#[derive(thiserror::Error, Debug)] +pub enum ExecutionError { + /// Simple error + #[allow(dead_code)] + #[error("General execution error with reason: {0}.")] + GeneralError(String), + + /// Error when deserializing an operator. + #[error("Fail to deserialize to native operator with reason: {0}.")] + DeserializeError(String), + + /// Error when processing Arrow array. 
+ #[error("Fail to process Arrow array with reason: {0}.")] + ArrowError(String), + + /// DataFusion error + #[error("Error from DataFusion: {0}.")] + DataFusionError(String), + + #[error("{class}: {msg}")] + JavaException { + class: String, + msg: String, + throwable: GlobalRef, + }, +} + #[derive(thiserror::Error, Debug)] pub enum CometError { #[error("Configuration Error: {0}")] @@ -63,11 +90,6 @@ pub enum CometError { #[error("Comet Internal Error: {0}")] Internal(String), - /// CometError::Spark is typically used in native code to emulate the same errors - /// that Spark would return - #[error(transparent)] - Spark(SparkError), - #[error(transparent)] Arrow { #[from] @@ -215,6 +237,66 @@ impl From for ExecutionError { } } +impl From for ExpressionError { + fn from(error: prost::DecodeError) -> ExpressionError { + ExpressionError::Deserialize(error.to_string()) + } +} + +impl From for ExpressionError { + fn from(error: prost::UnknownEnumValue) -> ExpressionError { + ExpressionError::Deserialize(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: prost::DecodeError) -> ExecutionError { + ExecutionError::DeserializeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: prost::UnknownEnumValue) -> ExecutionError { + ExecutionError::DeserializeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: ArrowError) -> ExecutionError { + ExecutionError::ArrowError(error.to_string()) + } +} + +impl From for ExpressionError { + fn from(error: ArrowError) -> ExpressionError { + ExpressionError::ArrowError(error.to_string()) + } +} + +impl From for ArrowError { + fn from(error: ExpressionError) -> ArrowError { + ArrowError::ComputeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(value: DataFusionError) -> Self { + ExecutionError::DataFusionError(value.message().to_string()) + } +} + +impl From for DataFusionError { + fn from(value: ExecutionError) -> Self { + 
DataFusionError::Execution(value.to_string()) + } +} + +impl From for DataFusionError { + fn from(value: ExpressionError) -> Self { + DataFusionError::Execution(value.to_string()) + } +} + impl jni::errors::ToException for CometError { fn to_exception(&self) -> Exception { match self { @@ -226,10 +308,6 @@ impl jni::errors::ToException for CometError { class: "java/lang/NullPointerException".to_string(), msg: self.to_string(), }, - CometError::Spark(spark_err) => Exception { - class: spark_err.exception_class().to_string(), - msg: spark_err.to_string(), - }, CometError::NumberIntFormat { source: s } => Exception { class: "java/lang/NumberFormatException".to_string(), msg: s.to_string(), @@ -280,8 +358,9 @@ pub type CometResult = result::Result; // ---------------------------------------------------------------------- // Convenient macros for different errors +#[macro_export] macro_rules! general_err { - ($fmt:expr, $($args:expr),*) => (crate::CometError::from(parquet::errors::ParquetError::General(format!($fmt, $($args),*)))); + ($fmt:expr, $($args:expr),*) => ($crate::errors::CometError::from(parquet::errors::ParquetError::General(format!($fmt, $($args),*)))); } /// Returns the "default value" for a type. 
This is used for JNI code in order to facilitate @@ -395,25 +474,25 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), - // Handle DataFusion errors containing SparkError - serialize to JSON + // Handle DataFusion errors containing SparkError or SparkErrorWithContext CometError::DataFusion { msg: _, source: DataFusionError::External(e), } => { - // Try SparkErrorWithContext first (includes context) - if let Some(spark_error_with_ctx) = - e.downcast_ref::() - { + if let Some(spark_error_with_ctx) = e.downcast_ref::() { let json_message = spark_error_with_ctx.to_json(); env.throw_new( "org/apache/comet/exceptions/CometQueryExecutionException", json_message, ) } else if let Some(spark_error) = e.downcast_ref::() { - // Fall back to plain SparkError (no context) - throw_spark_error_as_json(env, spark_error) + let json_message = spark_error.to_json(); + env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ) } else { - // Not a SparkError, use generic exception + // Fall through to generic exception let exception = error.to_exception(); match backtrace { Some(backtrace_string) => env.throw_new( @@ -424,8 +503,6 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option throw_spark_error_as_json(env, spark_error), _ => { let exception = error.to_exception(); match backtrace { @@ -441,21 +518,6 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option jni::errors::Result<()> { - // Serialize error to JSON - let json_message = spark_error.to_json(); - - // Throw CometQueryExecutionException with JSON message - env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, - ) -} - #[derive(Debug, Error)] enum StacktraceError { #[error("Unable to initialize message: {0}")] @@ -463,6 +525,7 @@ enum StacktraceError { #[error("Unable to initialize backtrace regex: {0}")] Regex(#[from] 
regex::Error), #[error("Required field missing: {0}")] + #[allow(non_camel_case_types)] Required_Field(String), #[error("Unable to format stacktrace element: {0}")] Element(#[from] std::fmt::Error), diff --git a/native/core/src/jvm_bridge/mod.rs b/native/jni-bridge/src/lib.rs similarity index 87% rename from native/core/src/jvm_bridge/mod.rs rename to native/jni-bridge/src/lib.rs index 00fe7b33c3..456fbdf688 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/jni-bridge/src/lib.rs @@ -15,19 +15,28 @@ // specific language governing permissions and limitations // under the License. -//! JNI JVM related functions +//! JNI bridge for Apache DataFusion Comet. +//! +//! This crate provides the JNI interaction layer used across Comet's native Rust crates. -use crate::errors::CometResult; +#![allow(clippy::result_large_err)] use jni::objects::JClass; use jni::{ errors::Error, objects::{JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, signature::ReturnType, - AttachGuard, JNIEnv, + AttachGuard, JNIEnv, JavaVM, }; use once_cell::sync::OnceCell; +use errors::{CometError, CometResult}; + +pub mod errors; + +/// Global reference to the Java VM, initialized during native library setup. +pub static JAVA_VM: OnceCell = OnceCell::new(); + /// Macro for converting JNI Error to Comet Error. #[macro_export] macro_rules! jni_map_error { @@ -40,6 +49,7 @@ macro_rules! jni_map_error { } /// Macro for converting Rust types to JNI types. +#[macro_export] macro_rules! jvalues { ($($args:expr,)* $(,)?) => {{ &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] @@ -53,57 +63,67 @@ macro_rules! jvalues { /// metric_node is the Java object on which the method is called. /// add is the method name. /// jname and value are the arguments. +#[macro_export] macro_rules! jni_call { ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ let method_id = paste::paste! 
{ - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + $crate::JVMClasses::get().[<$clsname>].[] }; let ret_type = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + $crate::JVMClasses::get().[<$clsname>].[] }.clone(); - let args = $crate::jvm_bridge::jvalues!($($args,)*); + let args = $crate::jvalues!($($args,)*); // Call the JVM method and obtain the returned value let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { + let result = if let Some(exception) = $crate::check_exception($env)? { Err(exception.into()) } else { - $crate::jvm_bridge::jni_map_error!($env, ret) + $crate::jni_map_error!($env, ret) }; - result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) }} } +#[macro_export] macro_rules! jni_static_call { ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ let clazz = &paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + $crate::JVMClasses::get().[<$clsname>].[] }; let method_id = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + $crate::JVMClasses::get().[<$clsname>].[] }; let ret_type = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] + $crate::JVMClasses::get().[<$clsname>].[] }.clone(); - let args = $crate::jvm_bridge::jvalues!($($args,)*); + let args = $crate::jvalues!($($args,)*); // Call the JVM static method and obtain the returned value let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args); // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { + let result = if let Some(exception) = $crate::check_exception($env)? 
{ Err(exception.into()) } else { - $crate::jvm_bridge::jni_map_error!($env, ret) + $crate::jni_map_error!($env, ret) }; - result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) }} } +/// Macro for creating a new global reference. +#[macro_export] +macro_rules! jni_new_global_ref { + ($env:expr, $obj:expr) => {{ + $crate::jni_map_error!($env, $env.new_global_ref($obj)) + }}; +} + /// Wrapper for JString. Because we cannot implement `TryFrom` trait for `JString` as they /// are defined in different crates. pub struct StringWrapper<'a> { @@ -156,26 +176,12 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { } } -/// Macro for creating a new global reference. -macro_rules! jni_new_global_ref { - ($env:expr, $obj:expr) => {{ - $crate::jni_map_error!($env, $env.new_global_ref($obj)) - }}; -} - -pub(crate) use jni_call; -pub(crate) use jni_map_error; -pub(crate) use jni_new_global_ref; -pub(crate) use jni_static_call; -pub(crate) use jvalues; - mod comet_exec; pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; -use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; @@ -190,13 +196,13 @@ pub struct JVMClasses<'a> { /// Cached JClass for "java.lang.Throwable" java_lang_throwable: JClass<'a>, /// Cached method ID for "java.lang.Object#getClass" - pub object_get_class_method: JMethodID, + object_get_class_method: JMethodID, /// Cached method ID for "java.lang.Class#getName" - pub class_get_name_method: JMethodID, + class_get_name_method: JMethodID, /// Cached method ID for "java.lang.Throwable#getMessage" - pub throwable_get_message_method: JMethodID, + throwable_get_message_method: JMethodID, /// Cached method ID for "java.lang.Throwable#getCause" - pub throwable_get_cause_method: JMethodID, + throwable_get_cause_method: 
JMethodID, /// The CometMetricNode class. Used for updating the metrics. pub comet_metric_node: CometMetricNode<'a>, @@ -287,7 +293,7 @@ impl JVMClasses<'_> { } } -pub(crate) fn check_exception(env: &mut JNIEnv) -> CometResult> { +pub fn check_exception(env: &mut JNIEnv) -> CometResult> { let result = if env.exception_check()? { let exception = env.exception_occurred()?; env.exception_clear()?; @@ -380,10 +386,7 @@ fn get_throwable_message( /// this converts it into a `CometError::JavaException` with the exception class name /// and exception message. This error can then be populated to the JVM side to let /// users know the cause of the native side error. -pub(crate) fn convert_exception( - env: &mut JNIEnv, - throwable: &JThrowable, -) -> CometResult { +pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) -> CometResult { let cache = JVMClasses::get(); let exception_class_name_str = get_throwable_class_name(env, cache, throwable)?; let message_str = get_throwable_message(env, cache, throwable)?; diff --git a/native/core/testdata/backtrace.txt b/native/jni-bridge/testdata/backtrace.txt similarity index 100% rename from native/core/testdata/backtrace.txt rename to native/jni-bridge/testdata/backtrace.txt diff --git a/native/core/testdata/stacktrace.txt b/native/jni-bridge/testdata/stacktrace.txt similarity index 100% rename from native/core/testdata/stacktrace.txt rename to native/jni-bridge/testdata/stacktrace.txt diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 9f08e480f2..a249b7de3c 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -33,9 +33,8 @@ datafusion = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } -serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -thiserror = { workspace = true } +datafusion-comet-common = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = 
true } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index ae3b5c0eda..bb87915c7b 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -15,855 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::error::ArrowError; -use datafusion::common::DataFusionError; -use std::sync::Arc; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum SparkError { - // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet - #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - because it is malformed. Correct the value as per the syntax, or change its target type. \ - Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastInvalidValue { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] - NumericValueOutOfRange { - value: String, - precision: u8, - scale: i8, - }, - - #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] - NumericOutOfRange { value: String }, - - #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastOverFlow { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] - CannotParseDecimal, - - #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - ArithmeticOverflow { from_type: String }, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntegralDivideOverflow, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DecimalSumOverflow, - - #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DivideByZero, - - #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - RemainderByZero, - - #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] - IntervalDividedByZero, - - #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - BinaryArithmeticOverflow { - value1: String, - symbol: String, - value2: String, - function_name: String, - }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithSuggestion { function_name: String }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithoutSuggestion, - - #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] - DatetimeOverflow, - - #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidArrayIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidElementAtIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] - InvalidBitmapPosition { - bit_position: i64, - bitmap_num_bytes: i64, - bitmap_num_bits: i64, - }, - - #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. 
An index shall be either < 0 or > 0 (the first element has index 1).")] - InvalidIndexOfZero, - - #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] - DuplicatedMapKey { key: String }, - - #[error("[NULL_MAP_KEY] Cannot use null as map key.")] - NullMapKey, - - #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] - MapKeyValueDiffSizes, - - #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] - ExceedMapSizeLimit { size: i32, max_size: i32 }, - - #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] - CollectionSizeLimitExceeded { - num_elements: i64, - max_elements: i64, - }, - - #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] - NotNullAssertViolation { field_name: String }, - - #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] - ValueIsNull { field_name: String, row_index: i32 }, - - #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] - CannotParseTimestamp { - message: String, - suggested_func: String, - }, - - #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidFractionOfSecond { value: f64 }, - - #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] - InvalidUtf8String { hex_string: String }, - - #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] - UnexpectedPositiveValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. 
The actual value is {actual_value}.")] - UnexpectedNegativeValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. Group count is {group_count}.")] - InvalidRegexGroupIndex { - function_name: String, - group_count: i32, - group_index: i32, - }, - - #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] - DatatypeCannotOrder { data_type: String }, - - #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] - ScalarSubqueryTooManyRows, - - #[error("ArrowError: {0}.")] - Arrow(Arc), - - #[error("InternalError: {0}.")] - Internal(String), -} - -impl SparkError { - /// Serialize this error to JSON format for JNI transfer - pub fn to_json(&self) -> String { - let error_class = self.error_class().unwrap_or(""); - - // Create a JSON structure with errorType, errorClass, and params - match serde_json::to_string(&serde_json::json!({ - "errorType": self.error_type_name(), - "errorClass": error_class, - "params": self.params_as_json(), - })) { - Ok(json) => json, - Err(e) => { - // Fallback if serialization fails - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - } - } - } - - /// Get the error type name for JSON serialization - fn error_type_name(&self) -> &'static str { - match self { - SparkError::CastInvalidValue { .. } => "CastInvalidValue", - SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", - SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", - SparkError::CastOverFlow { .. } => "CastOverFlow", - SparkError::CannotParseDecimal => "CannotParseDecimal", - SparkError::ArithmeticOverflow { .. 
} => "ArithmeticOverflow", - SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", - SparkError::DecimalSumOverflow => "DecimalSumOverflow", - SparkError::DivideByZero => "DivideByZero", - SparkError::RemainderByZero => "RemainderByZero", - SparkError::IntervalDividedByZero => "IntervalDividedByZero", - SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - "IntervalArithmeticOverflowWithSuggestion" - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - "IntervalArithmeticOverflowWithoutSuggestion" - } - SparkError::DatetimeOverflow => "DatetimeOverflow", - SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", - SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", - SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", - SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", - SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", - SparkError::NullMapKey => "NullMapKey", - SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", - SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", - SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", - SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", - SparkError::ValueIsNull { .. } => "ValueIsNull", - SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", - SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", - SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", - SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", - SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", - SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", - SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", - SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", - SparkError::Arrow(_) => "Arrow", - SparkError::Internal(_) => "Internal", - } - } - - /// Extract parameters as JSON value - fn params_as_json(&self) -> serde_json::Value { - match self { - SparkError::CastInvalidValue { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::NumericValueOutOfRange { - value, - precision, - scale, - } => { - serde_json::json!({ - "value": value, - "precision": precision, - "scale": scale, - }) - } - SparkError::NumericOutOfRange { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::CastOverFlow { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::ArithmeticOverflow { from_type } => { - serde_json::json!({ - "fromType": from_type, - }) - } - SparkError::BinaryArithmeticOverflow { - value1, - symbol, - value2, - function_name, - } => { - serde_json::json!({ - "value1": value1, - "symbol": symbol, - "value2": value2, - "functionName": function_name, - }) - } - SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { - serde_json::json!({ - "functionName": function_name, - }) - } - SparkError::InvalidArrayIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidElementAtIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidBitmapPosition { - bit_position, - bitmap_num_bytes, - bitmap_num_bits, - } => { - serde_json::json!({ - "bitPosition": bit_position, - "bitmapNumBytes": bitmap_num_bytes, - "bitmapNumBits": bitmap_num_bits, - }) - } - SparkError::DuplicatedMapKey { key } => { - serde_json::json!({ - "key": 
key, - }) - } - SparkError::ExceedMapSizeLimit { size, max_size } => { - serde_json::json!({ - "size": size, - "maxSize": max_size, - }) - } - SparkError::CollectionSizeLimitExceeded { - num_elements, - max_elements, - } => { - serde_json::json!({ - "numElements": num_elements, - "maxElements": max_elements, - }) - } - SparkError::NotNullAssertViolation { field_name } => { - serde_json::json!({ - "fieldName": field_name, - }) - } - SparkError::ValueIsNull { - field_name, - row_index, - } => { - serde_json::json!({ - "fieldName": field_name, - "rowIndex": row_index, - }) - } - SparkError::CannotParseTimestamp { - message, - suggested_func, - } => { - serde_json::json!({ - "message": message, - "suggestedFunc": suggested_func, - }) - } - SparkError::InvalidFractionOfSecond { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::InvalidUtf8String { hex_string } => { - serde_json::json!({ - "hexString": hex_string, - }) - } - SparkError::UnexpectedPositiveValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::UnexpectedNegativeValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::InvalidRegexGroupIndex { - function_name, - group_count, - group_index, - } => { - serde_json::json!({ - "functionName": function_name, - "groupCount": group_count, - "groupIndex": group_index, - }) - } - SparkError::DatatypeCannotOrder { data_type } => { - serde_json::json!({ - "dataType": data_type, - }) - } - SparkError::Arrow(e) => { - serde_json::json!({ - "message": e.to_string(), - }) - } - SparkError::Internal(msg) => { - serde_json::json!({ - "message": msg, - }) - } - // Simple errors with no parameters - _ => serde_json::json!({}), - } - } - - /// Returns the appropriate Spark exception class for this error - pub fn exception_class(&self) -> &'static 
str { - match self { - // ArithmeticException - SparkError::DivideByZero - | SparkError::RemainderByZero - | SparkError::IntervalDividedByZero - | SparkError::NumericValueOutOfRange { .. } - | SparkError::NumericOutOfRange { .. } // Comet-specific extension - | SparkError::ArithmeticOverflow { .. } - | SparkError::IntegralDivideOverflow - | SparkError::DecimalSumOverflow - | SparkError::BinaryArithmeticOverflow { .. } - | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } - | SparkError::IntervalArithmeticOverflowWithoutSuggestion - | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", - - // CastOverflow gets special handling with CastOverflowException - SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", - - // NumberFormatException (for cast invalid input errors) - SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", - - // ArrayIndexOutOfBoundsException - SparkError::InvalidArrayIndex { .. } - | SparkError::InvalidElementAtIndex { .. } - | SparkError::InvalidBitmapPosition { .. } - | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", - - // RuntimeException - SparkError::CannotParseDecimal - | SparkError::DuplicatedMapKey { .. } - | SparkError::NullMapKey - | SparkError::MapKeyValueDiffSizes - | SparkError::ExceedMapSizeLimit { .. } - | SparkError::CollectionSizeLimitExceeded { .. } - | SparkError::NotNullAssertViolation { .. } - | SparkError::ValueIsNull { .. } // Comet-specific extension - | SparkError::UnexpectedPositiveValue { .. } - | SparkError::UnexpectedNegativeValue { .. } - | SparkError::InvalidRegexGroupIndex { .. } - | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", - - // DateTimeException - SparkError::CannotParseTimestamp { .. } - | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", - - // IllegalArgumentException - SparkError::DatatypeCannotOrder { .. } - | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", - - // Generic errors - SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", - } - } - - /// Returns the Spark error class code for this error - pub fn error_class(&self) -> Option<&'static str> { - match self { - // Cast errors - SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), - SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), - SparkError::NumericValueOutOfRange { .. } => { - Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") - } - SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), - SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), - - // Arithmetic errors - SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), - SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), - SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), - SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), - SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), - - // Array index errors - SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), - SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), - SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), - SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), - - // Map/Collection errors - SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), - SparkError::NullMapKey => Some("NULL_MAP_KEY"), - SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), - SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), - SparkError::CollectionSizeLimitExceeded { .. } => { - Some("COLLECTION_SIZE_LIMIT_EXCEEDED") - } - - // Null validation errors - SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), - SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), - - // DateTime errors - SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), - SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), - - // String/UTF8 errors - SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), - - // Function parameter errors - SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), - SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), - - // Regex errors - SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), - - // Unsupported operation errors - SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), - - // Subquery errors - SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), - - // Generic errors (no error class) - SparkError::Arrow(_) | SparkError::Internal(_) => None, - } - } -} - -/// Convert decimal overflow to SparkError::NumericValueOutOfRange. -/// -/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage. 
-/// -/// # Arguments -/// * `value` - The i128 decimal value that overflowed -/// * `precision` - The target precision -/// * `scale` - The scale of the decimal -/// -/// # Returns -/// SparkError::NumericValueOutOfRange with the value, precision, and scale -pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { - SparkError::NumericValueOutOfRange { - value: value.to_string(), - precision, - scale, - } -} - -pub type SparkResult = Result; - -/// Wrapper that adds QueryContext to SparkError -/// -/// This allows attaching SQL context information (query text, line/position, object name) to errors -#[derive(Debug, Clone)] -pub struct SparkErrorWithContext { - /// The underlying SparkError - pub error: SparkError, - /// Optional QueryContext for SQL location information - pub context: Option>, -} - -impl SparkErrorWithContext { - /// Create a SparkErrorWithContext without context - pub fn new(error: SparkError) -> Self { - Self { - error, - context: None, - } - } - - /// Create a SparkErrorWithContext with QueryContext - pub fn with_context(error: SparkError, context: Arc) -> Self { - Self { - error, - context: Some(context), - } - } - - /// Serialize to JSON including optional context field - /// - /// JSON structure: - /// ```json - /// { - /// "errorType": "DivideByZero", - /// "errorClass": "DIVIDE_BY_ZERO", - /// "params": {}, - /// "context": { - /// "sqlText": "SELECT a/b FROM t", - /// "startIndex": 7, - /// "stopIndex": 9, - /// "line": 1, - /// "startPosition": 7 - /// }, - /// "summary": "== SQL (line 1, position 8) ==\n..." 
- /// } - /// ``` - pub fn to_json(&self) -> String { - let mut json_obj = serde_json::json!({ - "errorType": self.error.error_type_name(), - "errorClass": self.error.error_class().unwrap_or(""), - "params": self.error.params_as_json(), - }); - - if let Some(ctx) = &self.context { - // Serialize context fields - json_obj["context"] = serde_json::json!({ - "sqlText": ctx.sql_text.as_str(), - "startIndex": ctx.start_index, - "stopIndex": ctx.stop_index, - "objectType": ctx.object_type, - "objectName": ctx.object_name, - "line": ctx.line, - "startPosition": ctx.start_position, - }); - - // Add formatted summary - json_obj["summary"] = serde_json::json!(ctx.format_summary()); - } - - serde_json::to_string(&json_obj).unwrap_or_else(|e| { - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - }) - } -} - -impl std::fmt::Display for SparkErrorWithContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error)?; - if let Some(ctx) = &self.context { - write!(f, "\n{}", ctx.format_summary())?; - } - Ok(()) - } -} - -impl std::error::Error for SparkErrorWithContext {} - -impl From for SparkErrorWithContext { - fn from(error: SparkError) -> Self { - SparkErrorWithContext::new(error) - } -} - -impl From for DataFusionError { - fn from(value: SparkErrorWithContext) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -impl From for SparkError { - fn from(value: ArrowError) -> Self { - SparkError::Arrow(Arc::new(value)) - } -} - -impl From for DataFusionError { - fn from(value: SparkError) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_divide_by_zero_json() { - let error = SparkError::DivideByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"DivideByZero\"")); - assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); - - // Verify it's valid JSON - let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DivideByZero"); - assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); - } - - #[test] - fn test_remainder_by_zero_json() { - let error = SparkError::RemainderByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"RemainderByZero\"")); - assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); - } - - #[test] - fn test_binary_overflow_json() { - let error = SparkError::BinaryArithmeticOverflow { - value1: "32767".to_string(), - symbol: "+".to_string(), - value2: "1".to_string(), - function_name: "try_add".to_string(), - }; - let json = error.to_json(); - - // Verify it's valid JSON - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); - assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); - assert_eq!(parsed["params"]["value1"], "32767"); - assert_eq!(parsed["params"]["symbol"], "+"); - assert_eq!(parsed["params"]["value2"], "1"); - assert_eq!(parsed["params"]["functionName"], "try_add"); - } - - #[test] - fn test_invalid_array_index_json() { - let error = SparkError::InvalidArrayIndex { - index_value: 10, - array_size: 3, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "InvalidArrayIndex"); - assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); - assert_eq!(parsed["params"]["indexValue"], 10); - assert_eq!(parsed["params"]["arraySize"], 3); - } - - #[test] - fn test_numeric_value_out_of_range_json() { - let error = SparkError::NumericValueOutOfRange { - value: "999.99".to_string(), - precision: 5, - scale: 2, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); - assert_eq!( - parsed["errorClass"], - "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" - ); - 
assert_eq!(parsed["params"]["value"], "999.99"); - assert_eq!(parsed["params"]["precision"], 5); - assert_eq!(parsed["params"]["scale"], 2); - } - - #[test] - fn test_cast_invalid_value_json() { - let error = SparkError::CastInvalidValue { - value: "abc".to_string(), - from_type: "STRING".to_string(), - to_type: "INT".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "CastInvalidValue"); - assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); - assert_eq!(parsed["params"]["value"], "abc"); - assert_eq!(parsed["params"]["fromType"], "STRING"); - assert_eq!(parsed["params"]["toType"], "INT"); - } - - #[test] - fn test_duplicated_map_key_json() { - let error = SparkError::DuplicatedMapKey { - key: "duplicate_key".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DuplicatedMapKey"); - assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); - assert_eq!(parsed["params"]["key"], "duplicate_key"); - } - - #[test] - fn test_null_map_key_json() { - let error = SparkError::NullMapKey; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NullMapKey"); - assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); - // Params should be an empty object - assert_eq!(parsed["params"], serde_json::json!({})); - } - - #[test] - fn test_error_class_mapping() { - // Test that error_class() returns the correct error class - assert_eq!( - SparkError::DivideByZero.error_class(), - Some("DIVIDE_BY_ZERO") - ); - assert_eq!( - SparkError::RemainderByZero.error_class(), - Some("REMAINDER_BY_ZERO") - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .error_class(), - Some("INVALID_ARRAY_INDEX") - ); - assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); - } - - #[test] - fn test_exception_class_mapping() { - // Test that exception_class() returns the correct Java exception class - assert_eq!( - SparkError::DivideByZero.exception_class(), - "org/apache/spark/SparkArithmeticException" - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .exception_class(), - "org/apache/spark/SparkArrayIndexOutOfBoundsException" - ); - assert_eq!( - SparkError::NullMapKey.exception_class(), - "org/apache/spark/SparkRuntimeException" - ); - } -} +// Re-export all error types from the common crate +pub use datafusion_comet_common::{ + decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult, +}; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 072fa1fad7..d576c803a1 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -74,7 +74,7 @@ pub use datetime_funcs::{ SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr, }; -pub use error::{SparkError, SparkErrorWithContext, SparkResult}; +pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs index e6591135e0..7d96a3f1e5 100644 --- a/native/spark-expr/src/query_context.rs +++ b/native/spark-expr/src/query_context.rs @@ -15,388 +15,5 @@ // specific language governing permissions and limitations // under the License. -//! Query execution context for error reporting -//! -//! This module provides QueryContext which mirrors Spark's SQLQueryContext -//! for providing SQL text, line/position information, and error location -//! pointers in exception messages. - -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -/// Based on Spark's SQLQueryContext for error reporting. 
-/// -/// Contains information about where an error occurred in a SQL query, -/// including the full SQL text, line/column positions, and object context. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct QueryContext { - /// Full SQL query text - #[serde(rename = "sqlText")] - pub sql_text: Arc<String>, - - /// Start offset in SQL text (0-based, character index) - #[serde(rename = "startIndex")] - pub start_index: i32, - - /// Stop offset in SQL text (0-based, character index, inclusive) - #[serde(rename = "stopIndex")] - pub stop_index: i32, - - /// Object type (e.g., "VIEW", "Project", "Filter") - #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] - pub object_type: Option<String>, - - /// Object name (e.g., view name, column name) - #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] - pub object_name: Option<String>, - - /// Line number in SQL query (1-based) - pub line: i32, - - /// Column position within the line (0-based) - #[serde(rename = "startPosition")] - pub start_position: i32, -} - -impl QueryContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - sql_text: String, - start_index: i32, - stop_index: i32, - object_type: Option<String>, - object_name: Option<String>, - line: i32, - start_position: i32, - ) -> Self { - Self { - sql_text: Arc::new(sql_text), - start_index, - stop_index, - object_type, - object_name, - line, - start_position, - } - } - - /// Convert a character index to a byte offset in the SQL text. - /// Returns None if the character index is out of range. - fn char_index_to_byte_offset(&self, char_index: usize) -> Option<usize> { - self.sql_text - .char_indices() - .nth(char_index) - .map(|(byte_offset, _)| byte_offset) - } - - /// Generate a summary string showing SQL fragment with error location. 
- /// (From SQLQueryContext.summary) - /// - /// Format example: - /// ```text - /// == SQL of VIEW v1 (line 1, position 8) == - /// SELECT a/b FROM t - /// ^^^ - /// ``` - pub fn format_summary(&self) -> String { - let start_char = self.start_index.max(0) as usize; - // stop_index is inclusive; fragment covers [start, stop] - let stop_char = (self.stop_index + 1).max(0) as usize; - - let fragment = match ( - self.char_index_to_byte_offset(start_char), - // stop_char may equal sql_text.chars().count() (one past the end) - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], - _ => "", - }; - - // Build the header line - let mut summary = String::from("== SQL"); - - if let Some(obj_type) = &self.object_type { - if !obj_type.is_empty() { - summary.push_str(" of "); - summary.push_str(obj_type); - - if let Some(obj_name) = &self.object_name { - if !obj_name.is_empty() { - summary.push(' '); - summary.push_str(obj_name); - } - } - } - } - - summary.push_str(&format!( - " (line {}, position {}) ==\n", - self.line, - self.start_position + 1 // Convert 0-based to 1-based for display - )); - - // Add the SQL text with fragment highlighted - summary.push_str(&self.sql_text); - summary.push('\n'); - - // Add caret pointer - let caret_position = self.start_position.max(0) as usize; - summary.push_str(&" ".repeat(caret_position)); - // fragment.chars().count() gives the correct display width for non-ASCII - summary.push_str(&"^".repeat(fragment.chars().count().max(1))); - - summary - } - - /// Returns the SQL fragment that caused the error. 
- pub fn fragment(&self) -> String { - let start_char = self.start_index.max(0) as usize; - let stop_char = (self.stop_index + 1).max(0) as usize; - - match ( - self.char_index_to_byte_offset(start_char), - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), - _ => String::new(), - } - } -} - -use std::collections::HashMap; -use std::sync::RwLock; - -/// Map that stores QueryContext information for expressions during execution. -/// -/// This map is populated during plan deserialization and accessed -/// during error creation to attach SQL context to exceptions. -#[derive(Debug)] -pub struct QueryContextMap { - /// Map from expression ID to QueryContext - contexts: RwLock<HashMap<u64, Arc<QueryContext>>>, -} - -impl QueryContextMap { - pub fn new() -> Self { - Self { - contexts: RwLock::new(HashMap::new()), - } - } - - /// Register a QueryContext for an expression ID. - /// - /// If the expression ID already exists, it will be replaced. - /// - /// # Arguments - /// * `expr_id` - Unique expression identifier from protobuf - /// * `context` - QueryContext containing SQL text and position info - pub fn register(&self, expr_id: u64, context: QueryContext) { - let mut contexts = self.contexts.write().unwrap(); - contexts.insert(expr_id, Arc::new(context)); - } - - /// Get the QueryContext for an expression ID. - /// - /// Returns None if no context is registered for this expression. - /// - /// # Arguments - /// * `expr_id` - Expression identifier to look up - pub fn get(&self, expr_id: u64) -> Option<Arc<QueryContext>> { - let contexts = self.contexts.read().unwrap(); - contexts.get(&expr_id).cloned() - } - - /// Clear all registered contexts. - /// - /// This is typically called after plan execution completes to free memory. 
- pub fn clear(&self) { - let mut contexts = self.contexts.write().unwrap(); - contexts.clear(); - } - - /// Return the number of registered contexts (for debugging/testing) - pub fn len(&self) -> usize { - let contexts = self.contexts.read().unwrap(); - contexts.len() - } - - /// Check if the map is empty - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl Default for QueryContextMap { - fn default() -> Self { - Self::new() - } -} - -/// Create a new session-scoped QueryContextMap. -/// -/// This should be called once per SessionContext during plan creation -/// and passed to expressions that need query context for error reporting. -pub fn create_query_context_map() -> Arc<QueryContextMap> { - Arc::new(QueryContextMap::new()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_query_context_creation() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); - assert_eq!(ctx.start_index, 7); - assert_eq!(ctx.stop_index, 9); - assert_eq!(ctx.object_type, Some("Divide".to_string())); - assert_eq!(ctx.object_name, Some("a/b".to_string())); - assert_eq!(ctx.line, 1); - assert_eq!(ctx.start_position, 7); - } - - #[test] - fn test_query_context_serialization() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - let json = serde_json::to_string(&ctx).unwrap(); - let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); - - assert_eq!(ctx, deserialized); - } - - #[test] - fn test_format_summary() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("VIEW".to_string()), - Some("v1".to_string()), - 1, - 7, - ); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM 
t")); - assert!(summary.contains("^^^")); // Three carets for "a/b" - } - - #[test] - fn test_format_summary_without_object() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM t")); - } - - #[test] - fn test_fragment() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - assert_eq!(ctx.fragment(), "a/b"); - } - - #[test] - fn test_arc_string_sharing() { - let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let ctx2 = ctx1.clone(); - - // Arc should share the same allocation - assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); - } - - #[test] - fn test_json_with_optional_fields() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let json = serde_json::to_string(&ctx).unwrap(); - - // Should not serialize objectType and objectName when None - assert!(!json.contains("objectType")); - assert!(!json.contains("objectName")); - } - - #[test] - fn test_map_register_and_get() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx.clone()); - - let retrieved = map.get(1).unwrap(); - assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); - assert_eq!(retrieved.start_index, 7); - } - - #[test] - fn test_map_get_nonexistent() { - let map = QueryContextMap::new(); - assert!(map.get(999).is_none()); - } - - #[test] - fn test_map_clear() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx); - assert_eq!(map.len(), 1); - - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - } - - // Verify that fragment() and format_summary() correctly handle SQL text that - // 
contains multi-byte characters - - #[test] - fn test_fragment_non_ascii_accented() { - // "é" is a 2-byte UTF-8 sequence (U+00E9). - // SQL: "SELECT café FROM t" - // 0123456789... - // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. - // start_index=7, stop_index=10 should yield "café" - let sql = "SELECT café FROM t".to_string(); - let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); - assert_eq!(ctx.fragment(), "café"); - } -} +// Re-export all query context types from the common crate +pub use datafusion_comet_common::{create_query_context_map, QueryContext, QueryContextMap};