From 16e9c99be161d31ea002ab16ddf475a84f4c840a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 09:34:31 -0600 Subject: [PATCH 01/11] refactor: extract jvm-bridge as separate crate Extract JNI bridge code (errors, jvm_bridge, SparkError, QueryContext) from core and spark-expr into a new datafusion-comet-jvm-bridge crate. This fixes the dependency direction so that jvm-bridge is the lowest-level crate with no dependency on spark-expr. The dependency graph is now: jvm-bridge <- spark-expr <- core Moved to jvm-bridge: - errors.rs (CometError, ExecutionError, ExpressionError, JNI helpers) - jvm_bridge/ (JNI macros, JVMClasses, cached method IDs) - SparkError, SparkErrorWithContext, SparkResult, decimal_overflow_error - QueryContext, QueryContextMap, create_query_context_map spark-expr re-exports SparkError/QueryContext types for backward compatibility. --- native/Cargo.lock | 22 + native/Cargo.toml | 5 +- native/core/Cargo.toml | 1 + .../src/execution/expressions/subquery.rs | 2 +- native/core/src/execution/jni_api.rs | 2 +- .../src/execution/memory_pools/fair_pool.rs | 5 +- .../execution/memory_pools/unified_pool.rs | 5 +- native/core/src/execution/metrics/utils.rs | 2 +- native/core/src/execution/operators/mod.rs | 32 +- native/core/src/execution/operators/scan.rs | 2 +- native/core/src/execution/planner.rs | 35 +- native/core/src/execution/serde.rs | 24 - native/core/src/execution/utils.rs | 20 - native/core/src/lib.rs | 16 +- native/core/src/parquet/mod.rs | 2 +- native/jvm-bridge/Cargo.toml | 48 + native/{core => jvm-bridge}/src/errors.rs | 99 +- .../src/jvm_bridge/batch_iterator.rs | 0 .../src/jvm_bridge/comet_exec.rs | 0 .../src/jvm_bridge/comet_metric_node.rs | 0 .../jvm_bridge/comet_task_memory_manager.rs | 0 .../src/jvm_bridge/mod.rs | 43 +- native/jvm-bridge/src/lib.rs | 37 + native/jvm-bridge/src/query_context.rs | 402 ++++++++ native/jvm-bridge/src/spark_error.rs | 869 ++++++++++++++++++ .../testdata/backtrace.txt | 0 
.../testdata/stacktrace.txt | 0 native/spark-expr/Cargo.toml | 1 + native/spark-expr/src/error.rs | 856 +---------------- native/spark-expr/src/query_context.rs | 389 +------- 30 files changed, 1527 insertions(+), 1392 deletions(-) create mode 100644 native/jvm-bridge/Cargo.toml rename native/{core => jvm-bridge}/src/errors.rs (91%) rename native/{core => jvm-bridge}/src/jvm_bridge/batch_iterator.rs (100%) rename native/{core => jvm-bridge}/src/jvm_bridge/comet_exec.rs (100%) rename native/{core => jvm-bridge}/src/jvm_bridge/comet_metric_node.rs (100%) rename native/{core => jvm-bridge}/src/jvm_bridge/comet_task_memory_manager.rs (100%) rename native/{core => jvm-bridge}/src/jvm_bridge/mod.rs (94%) create mode 100644 native/jvm-bridge/src/lib.rs create mode 100644 native/jvm-bridge/src/query_context.rs create mode 100644 native/jvm-bridge/src/spark_error.rs rename native/{core => jvm-bridge}/testdata/backtrace.txt (100%) rename native/{core => jvm-bridge}/testdata/stacktrace.txt (100%) diff --git a/native/Cargo.lock b/native/Cargo.lock index 465454adc5..0432c07813 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1837,6 +1837,7 @@ dependencies = [ "crc32fast", "criterion", "datafusion", + "datafusion-comet-jvm-bridge", "datafusion-comet-objectstore-hdfs", "datafusion-comet-proto", "datafusion-comet-spark-expr", @@ -1899,6 +1900,26 @@ dependencies = [ "uuid", ] +[[package]] +name = "datafusion-comet-jvm-bridge" +version = "0.14.0" +dependencies = [ + "arrow", + "assertables", + "datafusion", + "jni", + "lazy_static", + "log", + "once_cell", + "parquet", + "paste", + "prost", + "regex", + "serde", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "datafusion-comet-objectstore-hdfs" version = "0.14.0" @@ -1931,6 +1952,7 @@ dependencies = [ "chrono-tz", "criterion", "datafusion", + "datafusion-comet-jvm-bridge", "futures", "hex", "num", diff --git a/native/Cargo.toml b/native/Cargo.toml index 7979978a31..847045a9eb 100644 --- a/native/Cargo.toml 
+++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. [workspace] -default-members = ["core", "spark-expr", "proto"] -members = ["core", "spark-expr", "proto", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "proto", "jvm-bridge"] +members = ["core", "spark-expr", "proto", "jvm-bridge", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -43,6 +43,7 @@ datafusion-datasource = { version = "52.2.0" } datafusion-physical-expr-adapter = { version = "52.2.0" } datafusion-spark = { version = "52.2.0" } datafusion-comet-spark-expr = { path = "spark-expr" } +datafusion-comet-jvm-bridge = { path = "jvm-bridge" } datafusion-comet-proto = { path = "proto" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.10" } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 9c4ec9775c..e7d3c96ae6 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -68,6 +68,7 @@ regex = { workspace = true } crc32fast = "1.3.2" simd-adler32 = "0.3.7" datafusion-comet-spark-expr = { workspace = true } +datafusion-comet-jvm-bridge = { workspace = true } datafusion-comet-proto = { workspace = true } object_store = { workspace = true } url = { workspace = true } diff --git a/native/core/src/execution/expressions/subquery.rs b/native/core/src/execution/expressions/subquery.rs index 52f9d13f12..ad4106c251 100644 --- a/native/core/src/execution/expressions/subquery.rs +++ b/native/core/src/execution/expressions/subquery.rs @@ -17,7 +17,7 @@ use crate::{ execution::utils::bytes_to_i128, - jvm_bridge::{jni_static_call, BinaryWrapper, JVMClasses, StringWrapper}, + jvm_bridge::{BinaryWrapper, JVMClasses, StringWrapper}, }; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Schema, TimeUnit}; diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..59ac674431 100644 --- a/native/core/src/execution/jni_api.rs +++ 
b/native/core/src/execution/jni_api.rs @@ -24,7 +24,7 @@ use crate::{ metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype, shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort, }, - jvm_bridge::{jni_new_global_ref, JVMClasses}, + jvm_bridge::JVMClasses, }; use arrow::array::{Array, RecordBatch, UInt32Array}; use arrow::compute::{take, TakeOptions}; diff --git a/native/core/src/execution/memory_pools/fair_pool.rs b/native/core/src/execution/memory_pools/fair_pool.rs index 1a98f91e49..2c25fe9443 100644 --- a/native/core/src/execution/memory_pools/fair_pool.rs +++ b/native/core/src/execution/memory_pools/fair_pool.rs @@ -22,10 +22,7 @@ use std::{ use jni::objects::GlobalRef; -use crate::{ - errors::CometResult, - jvm_bridge::{jni_call, JVMClasses}, -}; +use crate::{errors::CometResult, jvm_bridge::JVMClasses}; use datafusion::common::resources_err; use datafusion::execution::memory_pool::MemoryConsumer; use datafusion::{ diff --git a/native/core/src/execution/memory_pools/unified_pool.rs b/native/core/src/execution/memory_pools/unified_pool.rs index 88b2731072..3233dd6d40 100644 --- a/native/core/src/execution/memory_pools/unified_pool.rs +++ b/native/core/src/execution/memory_pools/unified_pool.rs @@ -23,10 +23,7 @@ use std::{ }, }; -use crate::{ - errors::CometResult, - jvm_bridge::{jni_call, JVMClasses}, -}; +use crate::{errors::CometResult, jvm_bridge::JVMClasses}; use datafusion::{ common::{resources_datafusion_err, DataFusionError}, execution::memory_pool::{MemoryPool, MemoryReservation}, diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 9ec35ad951..161c1f1cf9 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+use crate::errors::CometError; use crate::execution::spark_plan::SparkPlan; -use crate::{errors::CometError, jvm_bridge::jni_call}; use datafusion::physical_plan::metrics::MetricValue; use datafusion_comet_proto::spark_metric::NativeMetricNode; use jni::{objects::JObject, JNIEnv}; diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..3c3814a2b5 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -17,9 +17,7 @@ //! Operators -use std::fmt::Debug; - -use jni::objects::GlobalRef; +pub use crate::errors::ExecutionError; pub use copy::*; pub use iceberg_scan::*; @@ -35,31 +33,3 @@ mod csv_scan; pub mod projection; mod scan; pub use csv_scan::init_csv_datasource_exec; - -/// Error returned during executing operators. -#[derive(thiserror::Error, Debug)] -pub enum ExecutionError { - /// Simple error - #[allow(dead_code)] - #[error("General execution error with reason: {0}.")] - GeneralError(String), - - /// Error when deserializing an operator. - #[error("Fail to deserialize to native operator with reason: {0}.")] - DeserializeError(String), - - /// Error when processing Arrow array. 
- #[error("Fail to process Arrow array with reason: {0}.")] - ArrowError(String), - - /// DataFusion error - #[error("Error from DataFusion: {0}.")] - DataFusionError(String), - - #[error("{class}: {msg}")] - JavaException { - class: String, - msg: String, - throwable: GlobalRef, - }, -} diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 2543705fb0..2394912e41 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -21,7 +21,7 @@ use crate::{ execution::{ operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, utils::SparkArrowConvert, }, - jvm_bridge::{jni_call, JVMClasses}, + jvm_bridge::JVMClasses, }; use arrow::array::{make_array, ArrayData, ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, take, CastOptions}; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index b79b43f6c9..c9b9d4ab63 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -23,16 +23,13 @@ pub mod operator_registry; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; -use crate::{ - errors::ExpressionError, - execution::{ - expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, - planner::expression_registry::ExpressionRegistry, - planner::operator_registry::OperatorRegistry, - serde::to_arrow_datatype, - shuffle::ShuffleWriterExec, - }, +use crate::execution::{ + expressions::subquery::Subquery, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + planner::expression_registry::ExpressionRegistry, + planner::operator_registry::OperatorRegistry, + serde::to_arrow_datatype, + shuffle::ShuffleWriterExec, }; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; @@ -2626,24 
+2623,6 @@ impl PhysicalPlanner { } } -impl From for ExecutionError { - fn from(value: DataFusionError) -> Self { - ExecutionError::DataFusionError(value.message().to_string()) - } -} - -impl From for DataFusionError { - fn from(value: ExecutionError) -> Self { - DataFusionError::Execution(value.to_string()) - } -} - -impl From for DataFusionError { - fn from(value: ExpressionError) -> Self { - DataFusionError::Execution(value.to_string()) - } -} - /// Collects the indices of the columns in the input schema that are used in the expression /// and returns them as a pair of vectors, one for the left side and one for the right side. fn expr_to_columns( diff --git a/native/core/src/execution/serde.rs b/native/core/src/execution/serde.rs index e95fd7eca2..ae0554ee76 100644 --- a/native/core/src/execution/serde.rs +++ b/native/core/src/execution/serde.rs @@ -34,30 +34,6 @@ use datafusion_comet_proto::{ use prost::Message; use std::{io::Cursor, sync::Arc}; -impl From for ExpressionError { - fn from(error: prost::DecodeError) -> ExpressionError { - ExpressionError::Deserialize(error.to_string()) - } -} - -impl From for ExpressionError { - fn from(error: prost::UnknownEnumValue) -> ExpressionError { - ExpressionError::Deserialize(error.to_string()) - } -} - -impl From for ExecutionError { - fn from(error: prost::DecodeError) -> ExecutionError { - ExecutionError::DeserializeError(error.to_string()) - } -} - -impl From for ExecutionError { - fn from(error: prost::UnknownEnumValue) -> ExecutionError { - ExecutionError::DeserializeError(error.to_string()) - } -} - /// Deserialize bytes to protobuf type of expression pub fn deserialize_expr(buf: &[u8]) -> Result { match spark_expression::Expr::decode(&mut Cursor::new(buf)) { diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 9e6f2a56e7..f95423aa70 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -16,32 +16,12 @@ // under the License. 
/// Utils for array vector, etc. -use crate::errors::ExpressionError; use crate::execution::operators::ExecutionError; use arrow::{ array::ArrayData, - error::ArrowError, ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; -impl From for ExecutionError { - fn from(error: ArrowError) -> ExecutionError { - ExecutionError::ArrowError(error.to_string()) - } -} - -impl From for ExpressionError { - fn from(error: ArrowError) -> ExpressionError { - ExpressionError::ArrowError(error.to_string()) - } -} - -impl From for ArrowError { - fn from(error: ExpressionError) -> ArrowError { - ArrowError::ComputeError(error.to_string()) - } -} - pub trait SparkArrowConvert { /// Build Arrow Arrays from C data interface passed from Spark. /// It accepts a tuple (ArrowArray address, ArrowSchema address). diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 2b883bd7df..8ee3139ff9 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -26,9 +26,12 @@ #![deny(clippy::clone_on_ref_ptr)] extern crate core; +#[macro_use] +extern crate datafusion_comet_jvm_bridge; + use jni::{ objects::{JClass, JString}, - JNIEnv, JavaVM, + JNIEnv, }; use log::info; use log4rs::{ @@ -37,7 +40,6 @@ use log4rs::{ encode::pattern::PatternEncoder, Config, }; -use once_cell::sync::OnceCell; #[cfg(all( not(target_env = "msvc"), @@ -52,14 +54,16 @@ use tikv_jemallocator::Jemalloc; ))] use mimalloc::MiMalloc; +// Re-export from jvm-bridge crate for internal use +pub use datafusion_comet_jvm_bridge::errors; +pub use datafusion_comet_jvm_bridge::jvm_bridge; +pub use datafusion_comet_jvm_bridge::JAVA_VM; + use errors::{try_unwrap_or_throw, CometError, CometResult}; -#[macro_use] -mod errors; #[macro_use] pub mod common; pub mod execution; -mod jvm_bridge; pub mod parquet; #[cfg(all( @@ -77,8 +81,6 @@ static GLOBAL: Jemalloc = Jemalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -static JAVA_VM: OnceCell = OnceCell::new(); - #[no_mangle] pub extern "system" fn 
Java_org_apache_comet_NativeBase_init( e: JNIEnv, diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index f2b0e80ab2..f386b850d4 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -53,7 +53,7 @@ use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; use crate::execution::spark_plan::SparkPlan; use crate::execution::utils::SparkArrowConvert; -use crate::jvm_bridge::{jni_new_global_ref, JVMClasses}; +use crate::jvm_bridge::JVMClasses; use crate::parquet::data_type::AsBytes; use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID}; use crate::parquet::parquet_exec::init_datasource_exec; diff --git a/native/jvm-bridge/Cargo.toml b/native/jvm-bridge/Cargo.toml new file mode 100644 index 0000000000..5870d89b17 --- /dev/null +++ b/native/jvm-bridge/Cargo.toml @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-comet-jvm-bridge" +version = { workspace = true } +homepage = "https://datafusion.apache.org/comet" +repository = "https://github.com/apache/datafusion-comet" +authors = ["Apache DataFusion "] +description = "Apache DataFusion Comet: JVM bridge and error types" +readme = "README.md" +license = "Apache-2.0" +edition = "2021" + +publish = false + +[dependencies] +arrow = { workspace = true } +parquet = { workspace = true } +datafusion = { workspace = true } +jni = "0.21" +thiserror = { workspace = true } +regex = { workspace = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +lazy_static = "1.4.0" +once_cell = "1.18.0" +paste = "1.0.14" +log = "0.4" +prost = "0.14.3" + +[dev-dependencies] +jni = { version = "0.21", features = ["invocation"] } +assertables = "9" diff --git a/native/core/src/errors.rs b/native/jvm-bridge/src/errors.rs similarity index 91% rename from native/core/src/errors.rs rename to native/jvm-bridge/src/errors.rs index 7c8957dba7..6e8c818430 100644 --- a/native/core/src/errors.rs +++ b/native/jvm-bridge/src/errors.rs @@ -37,8 +37,7 @@ use std::{ // lifetime checker won't let us. use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; -use crate::execution::operators::ExecutionError; -use datafusion_comet_spark_expr::SparkError; +use crate::spark_error::{SparkError, SparkErrorWithContext}; use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; @@ -49,6 +48,34 @@ lazy_static! { static ref PANIC_BACKTRACE: Arc>> = Arc::new(Mutex::new(None)); } +/// Error returned during executing operators. +#[derive(thiserror::Error, Debug)] +pub enum ExecutionError { + /// Simple error + #[allow(dead_code)] + #[error("General execution error with reason: {0}.")] + GeneralError(String), + + /// Error when deserializing an operator. 
+ #[error("Fail to deserialize to native operator with reason: {0}.")] + DeserializeError(String), + + /// Error when processing Arrow array. + #[error("Fail to process Arrow array with reason: {0}.")] + ArrowError(String), + + /// DataFusion error + #[error("Error from DataFusion: {0}.")] + DataFusionError(String), + + #[error("{class}: {msg}")] + JavaException { + class: String, + msg: String, + throwable: GlobalRef, + }, +} + #[derive(thiserror::Error, Debug)] pub enum CometError { #[error("Configuration Error: {0}")] @@ -215,6 +242,66 @@ impl From for ExecutionError { } } +impl From for ExpressionError { + fn from(error: prost::DecodeError) -> ExpressionError { + ExpressionError::Deserialize(error.to_string()) + } +} + +impl From for ExpressionError { + fn from(error: prost::UnknownEnumValue) -> ExpressionError { + ExpressionError::Deserialize(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: prost::DecodeError) -> ExecutionError { + ExecutionError::DeserializeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: prost::UnknownEnumValue) -> ExecutionError { + ExecutionError::DeserializeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(error: ArrowError) -> ExecutionError { + ExecutionError::ArrowError(error.to_string()) + } +} + +impl From for ExpressionError { + fn from(error: ArrowError) -> ExpressionError { + ExpressionError::ArrowError(error.to_string()) + } +} + +impl From for ArrowError { + fn from(error: ExpressionError) -> ArrowError { + ArrowError::ComputeError(error.to_string()) + } +} + +impl From for ExecutionError { + fn from(value: DataFusionError) -> Self { + ExecutionError::DataFusionError(value.message().to_string()) + } +} + +impl From for DataFusionError { + fn from(value: ExecutionError) -> Self { + DataFusionError::Execution(value.to_string()) + } +} + +impl From for DataFusionError { + fn from(value: ExpressionError) -> Self { + 
DataFusionError::Execution(value.to_string()) + } +} + impl jni::errors::ToException for CometError { fn to_exception(&self) -> Exception { match self { @@ -280,8 +367,9 @@ pub type CometResult = result::Result; // ---------------------------------------------------------------------- // Convenient macros for different errors +#[macro_export] macro_rules! general_err { - ($fmt:expr, $($args:expr),*) => (crate::CometError::from(parquet::errors::ParquetError::General(format!($fmt, $($args),*)))); + ($fmt:expr, $($args:expr),*) => ($crate::errors::CometError::from(parquet::errors::ParquetError::General(format!($fmt, $($args),*)))); } /// Returns the "default value" for a type. This is used for JNI code in order to facilitate @@ -401,9 +489,7 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option { // Try SparkErrorWithContext first (includes context) - if let Some(spark_error_with_ctx) = - e.downcast_ref::() - { + if let Some(spark_error_with_ctx) = e.downcast_ref::() { let json_message = spark_error_with_ctx.to_json(); env.throw_new( "org/apache/comet/exceptions/CometQueryExecutionException", @@ -463,6 +549,7 @@ enum StacktraceError { #[error("Unable to initialize backtrace regex: {0}")] Regex(#[from] regex::Error), #[error("Required field missing: {0}")] + #[allow(non_camel_case_types)] Required_Field(String), #[error("Unable to format stacktrace element: {0}")] Element(#[from] std::fmt::Error), diff --git a/native/core/src/jvm_bridge/batch_iterator.rs b/native/jvm-bridge/src/jvm_bridge/batch_iterator.rs similarity index 100% rename from native/core/src/jvm_bridge/batch_iterator.rs rename to native/jvm-bridge/src/jvm_bridge/batch_iterator.rs diff --git a/native/core/src/jvm_bridge/comet_exec.rs b/native/jvm-bridge/src/jvm_bridge/comet_exec.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_exec.rs rename to native/jvm-bridge/src/jvm_bridge/comet_exec.rs diff --git a/native/core/src/jvm_bridge/comet_metric_node.rs 
b/native/jvm-bridge/src/jvm_bridge/comet_metric_node.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_metric_node.rs rename to native/jvm-bridge/src/jvm_bridge/comet_metric_node.rs diff --git a/native/core/src/jvm_bridge/comet_task_memory_manager.rs b/native/jvm-bridge/src/jvm_bridge/comet_task_memory_manager.rs similarity index 100% rename from native/core/src/jvm_bridge/comet_task_memory_manager.rs rename to native/jvm-bridge/src/jvm_bridge/comet_task_memory_manager.rs diff --git a/native/core/src/jvm_bridge/mod.rs b/native/jvm-bridge/src/jvm_bridge/mod.rs similarity index 94% rename from native/core/src/jvm_bridge/mod.rs rename to native/jvm-bridge/src/jvm_bridge/mod.rs index 00fe7b33c3..19389ceeb5 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/jvm-bridge/src/jvm_bridge/mod.rs @@ -40,6 +40,7 @@ macro_rules! jni_map_error { } /// Macro for converting Rust types to JNI types. +#[macro_export] macro_rules! jvalues { ($($args:expr,)* $(,)?) => {{ &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] @@ -53,6 +54,7 @@ macro_rules! jvalues { /// metric_node is the Java object on which the method is called. /// add is the method name. /// jname and value are the arguments. +#[macro_export] macro_rules! jni_call { ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ let method_id = paste::paste! { @@ -61,7 +63,7 @@ macro_rules! jni_call { let ret_type = paste::paste! { $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] }.clone(); - let args = $crate::jvm_bridge::jvalues!($($args,)*); + let args = $crate::jvalues!($($args,)*); // Call the JVM method and obtain the returned value let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); @@ -70,13 +72,14 @@ macro_rules! jni_call { let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? 
{ Err(exception.into()) } else { - $crate::jvm_bridge::jni_map_error!($env, ret) + $crate::jni_map_error!($env, ret) }; - result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) }} } +#[macro_export] macro_rules! jni_static_call { ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ let clazz = &paste::paste! { @@ -88,7 +91,7 @@ macro_rules! jni_static_call { let ret_type = paste::paste! { $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] }.clone(); - let args = $crate::jvm_bridge::jvalues!($($args,)*); + let args = $crate::jvalues!($($args,)*); // Call the JVM static method and obtain the returned value let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args); @@ -97,13 +100,21 @@ macro_rules! jni_static_call { let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { Err(exception.into()) } else { - $crate::jvm_bridge::jni_map_error!($env, ret) + $crate::jni_map_error!($env, ret) }; - result.and_then(|result| $crate::jvm_bridge::jni_map_error!($env, <$ret>::try_from(result))) + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) }} } +/// Macro for creating a new global reference. +#[macro_export] +macro_rules! jni_new_global_ref { + ($env:expr, $obj:expr) => {{ + $crate::jni_map_error!($env, $env.new_global_ref($obj)) + }}; +} + /// Wrapper for JString. Because we cannot implement `TryFrom` trait for `JString` as they /// are defined in different crates. pub struct StringWrapper<'a> { @@ -156,19 +167,6 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { } } -/// Macro for creating a new global reference. -macro_rules! 
jni_new_global_ref { - ($env:expr, $obj:expr) => {{ - $crate::jni_map_error!($env, $env.new_global_ref($obj)) - }}; -} - -pub(crate) use jni_call; -pub(crate) use jni_map_error; -pub(crate) use jni_new_global_ref; -pub(crate) use jni_static_call; -pub(crate) use jvalues; - mod comet_exec; pub use comet_exec::*; mod batch_iterator; @@ -287,7 +285,7 @@ impl JVMClasses<'_> { } } -pub(crate) fn check_exception(env: &mut JNIEnv) -> CometResult> { +pub fn check_exception(env: &mut JNIEnv) -> CometResult> { let result = if env.exception_check()? { let exception = env.exception_occurred()?; env.exception_clear()?; @@ -380,10 +378,7 @@ fn get_throwable_message( /// this converts it into a `CometError::JavaException` with the exception class name /// and exception message. This error can then be populated to the JVM side to let /// users know the cause of the native side error. -pub(crate) fn convert_exception( - env: &mut JNIEnv, - throwable: &JThrowable, -) -> CometResult { +pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) -> CometResult { let cache = JVMClasses::get(); let exception_class_name_str = get_throwable_class_name(env, cache, throwable)?; let message_str = get_throwable_message(env, cache, throwable)?; diff --git a/native/jvm-bridge/src/lib.rs b/native/jvm-bridge/src/lib.rs new file mode 100644 index 0000000000..d6a92af243 --- /dev/null +++ b/native/jvm-bridge/src/lib.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! JVM bridge and error types for Apache DataFusion Comet. +//! +//! This crate provides the JNI/JVM interaction layer and common error types +//! used across Comet's native Rust crates. + +#![allow(clippy::result_large_err)] + +use jni::JavaVM; +use once_cell::sync::OnceCell; + +pub mod errors; +pub mod jvm_bridge; +pub mod query_context; +pub mod spark_error; + +pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; +pub use spark_error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; + +/// Global reference to the Java VM, initialized during native library setup. +pub static JAVA_VM: OnceCell = OnceCell::new(); diff --git a/native/jvm-bridge/src/query_context.rs b/native/jvm-bridge/src/query_context.rs new file mode 100644 index 0000000000..e6591135e0 --- /dev/null +++ b/native/jvm-bridge/src/query_context.rs @@ -0,0 +1,402 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Query execution context for error reporting +//! +//! This module provides QueryContext which mirrors Spark's SQLQueryContext +//! for providing SQL text, line/position information, and error location +//! pointers in exception messages. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Based on Spark's SQLQueryContext for error reporting. +/// +/// Contains information about where an error occurred in a SQL query, +/// including the full SQL text, line/column positions, and object context. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct QueryContext { + /// Full SQL query text + #[serde(rename = "sqlText")] + pub sql_text: Arc, + + /// Start offset in SQL text (0-based, character index) + #[serde(rename = "startIndex")] + pub start_index: i32, + + /// Stop offset in SQL text (0-based, character index, inclusive) + #[serde(rename = "stopIndex")] + pub stop_index: i32, + + /// Object type (e.g., "VIEW", "Project", "Filter") + #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] + pub object_type: Option, + + /// Object name (e.g., view name, column name) + #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] + pub object_name: Option, + + /// Line number in SQL query (1-based) + pub line: i32, + + /// Column position within the line (0-based) + #[serde(rename = "startPosition")] + pub start_position: i32, +} + +impl QueryContext { + #[allow(clippy::too_many_arguments)] + pub fn new( + sql_text: String, + start_index: i32, + stop_index: i32, + object_type: 
Option, + object_name: Option, + line: i32, + start_position: i32, + ) -> Self { + Self { + sql_text: Arc::new(sql_text), + start_index, + stop_index, + object_type, + object_name, + line, + start_position, + } + } + + /// Convert a character index to a byte offset in the SQL text. + /// Returns None if the character index is out of range. + fn char_index_to_byte_offset(&self, char_index: usize) -> Option { + self.sql_text + .char_indices() + .nth(char_index) + .map(|(byte_offset, _)| byte_offset) + } + + /// Generate a summary string showing SQL fragment with error location. + /// (From SQLQueryContext.summary) + /// + /// Format example: + /// ```text + /// == SQL of VIEW v1 (line 1, position 8) == + /// SELECT a/b FROM t + /// ^^^ + /// ``` + pub fn format_summary(&self) -> String { + let start_char = self.start_index.max(0) as usize; + // stop_index is inclusive; fragment covers [start, stop] + let stop_char = (self.stop_index + 1).max(0) as usize; + + let fragment = match ( + self.char_index_to_byte_offset(start_char), + // stop_char may equal sql_text.chars().count() (one past the end) + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], + _ => "", + }; + + // Build the header line + let mut summary = String::from("== SQL"); + + if let Some(obj_type) = &self.object_type { + if !obj_type.is_empty() { + summary.push_str(" of "); + summary.push_str(obj_type); + + if let Some(obj_name) = &self.object_name { + if !obj_name.is_empty() { + summary.push(' '); + summary.push_str(obj_name); + } + } + } + } + + summary.push_str(&format!( + " (line {}, position {}) ==\n", + self.line, + self.start_position + 1 // Convert 0-based to 1-based for display + )); + + // Add the SQL text with fragment highlighted + summary.push_str(&self.sql_text); + summary.push('\n'); + + // Add 
caret pointer + let caret_position = self.start_position.max(0) as usize; + summary.push_str(&" ".repeat(caret_position)); + // fragment.chars().count() gives the correct display width for non-ASCII + summary.push_str(&"^".repeat(fragment.chars().count().max(1))); + + summary + } + + /// Returns the SQL fragment that caused the error. + pub fn fragment(&self) -> String { + let start_char = self.start_index.max(0) as usize; + let stop_char = (self.stop_index + 1).max(0) as usize; + + match ( + self.char_index_to_byte_offset(start_char), + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), + _ => String::new(), + } + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +/// Map that stores QueryContext information for expressions during execution. +/// +/// This map is populated during plan deserialization and accessed +/// during error creation to attach SQL context to exceptions. +#[derive(Debug)] +pub struct QueryContextMap { + /// Map from expression ID to QueryContext + contexts: RwLock>>, +} + +impl QueryContextMap { + pub fn new() -> Self { + Self { + contexts: RwLock::new(HashMap::new()), + } + } + + /// Register a QueryContext for an expression ID. + /// + /// If the expression ID already exists, it will be replaced. + /// + /// # Arguments + /// * `expr_id` - Unique expression identifier from protobuf + /// * `context` - QueryContext containing SQL text and position info + pub fn register(&self, expr_id: u64, context: QueryContext) { + let mut contexts = self.contexts.write().unwrap(); + contexts.insert(expr_id, Arc::new(context)); + } + + /// Get the QueryContext for an expression ID. + /// + /// Returns None if no context is registered for this expression. 
+ /// + /// # Arguments + /// * `expr_id` - Expression identifier to look up + pub fn get(&self, expr_id: u64) -> Option> { + let contexts = self.contexts.read().unwrap(); + contexts.get(&expr_id).cloned() + } + + /// Clear all registered contexts. + /// + /// This is typically called after plan execution completes to free memory. + pub fn clear(&self) { + let mut contexts = self.contexts.write().unwrap(); + contexts.clear(); + } + + /// Return the number of registered contexts (for debugging/testing) + pub fn len(&self) -> usize { + let contexts = self.contexts.read().unwrap(); + contexts.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for QueryContextMap { + fn default() -> Self { + Self::new() + } +} + +/// Create a new session-scoped QueryContextMap. +/// +/// This should be called once per SessionContext during plan creation +/// and passed to expressions that need query context for error reporting. +pub fn create_query_context_map() -> Arc { + Arc::new(QueryContextMap::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_context_creation() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); + assert_eq!(ctx.start_index, 7); + assert_eq!(ctx.stop_index, 9); + assert_eq!(ctx.object_type, Some("Divide".to_string())); + assert_eq!(ctx.object_name, Some("a/b".to_string())); + assert_eq!(ctx.line, 1); + assert_eq!(ctx.start_position, 7); + } + + #[test] + fn test_query_context_serialization() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + let json = serde_json::to_string(&ctx).unwrap(); + let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); + + assert_eq!(ctx, deserialized); + } + + #[test] + fn 
test_format_summary() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("VIEW".to_string()), + Some("v1".to_string()), + 1, + 7, + ); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + assert!(summary.contains("^^^")); // Three carets for "a/b" + } + + #[test] + fn test_format_summary_without_object() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + } + + #[test] + fn test_fragment() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + assert_eq!(ctx.fragment(), "a/b"); + } + + #[test] + fn test_arc_string_sharing() { + let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let ctx2 = ctx1.clone(); + + // Arc should share the same allocation + assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); + } + + #[test] + fn test_json_with_optional_fields() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let json = serde_json::to_string(&ctx).unwrap(); + + // Should not serialize objectType and objectName when None + assert!(!json.contains("objectType")); + assert!(!json.contains("objectName")); + } + + #[test] + fn test_map_register_and_get() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx.clone()); + + let retrieved = map.get(1).unwrap(); + assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); + assert_eq!(retrieved.start_index, 7); + } + + #[test] + fn test_map_get_nonexistent() { + let map = QueryContextMap::new(); + assert!(map.get(999).is_none()); + } + + #[test] + fn test_map_clear() { + let map = 
QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx); + assert_eq!(map.len(), 1); + + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + } + + // Verify that fragment() and format_summary() correctly handle SQL text that + // contains multi-byte characters + + #[test] + fn test_fragment_non_ascii_accented() { + // "é" is a 2-byte UTF-8 sequence (U+00E9). + // SQL: "SELECT café FROM t" + // 0123456789... + // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. + // start_index=7, stop_index=10 should yield "café" + let sql = "SELECT café FROM t".to_string(); + let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); + assert_eq!(ctx.fragment(), "café"); + } +} diff --git a/native/jvm-bridge/src/spark_error.rs b/native/jvm-bridge/src/spark_error.rs new file mode 100644 index 0000000000..ae3b5c0eda --- /dev/null +++ b/native/jvm-bridge/src/spark_error.rs @@ -0,0 +1,869 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::error::ArrowError; +use datafusion::common::DataFusionError; +use std::sync::Arc; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SparkError { + // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] + NumericOutOfRange { value: String }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] + CannotParseDecimal, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntegralDivideOverflow, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DecimalSumOverflow, + + #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DivideByZero, + + #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + RemainderByZero, + + #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] + IntervalDividedByZero, + + #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + BinaryArithmeticOverflow { + value1: String, + symbol: String, + value2: String, + function_name: String, + }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithSuggestion { function_name: String }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithoutSuggestion, + + #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] + DatetimeOverflow, + + #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. 
Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidArrayIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidElementAtIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] + InvalidBitmapPosition { + bit_position: i64, + bitmap_num_bytes: i64, + bitmap_num_bits: i64, + }, + + #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] + InvalidIndexOfZero, + + #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] + DuplicatedMapKey { key: String }, + + #[error("[NULL_MAP_KEY] Cannot use null as map key.")] + NullMapKey, + + #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] + MapKeyValueDiffSizes, + + #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] + ExceedMapSizeLimit { size: i32, max_size: i32 }, + + #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] + CollectionSizeLimitExceeded { + num_elements: i64, + max_elements: i64, + }, + + #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] + NotNullAssertViolation { field_name: String }, + + #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] + ValueIsNull { field_name: String, 
row_index: i32 }, + + #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] + CannotParseTimestamp { + message: String, + suggested_func: String, + }, + + #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidFractionOfSecond { value: f64 }, + + #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] + InvalidUtf8String { hex_string: String }, + + #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] + UnexpectedPositiveValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] + UnexpectedNegativeValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. 
Group count is {group_count}.")] + InvalidRegexGroupIndex { + function_name: String, + group_count: i32, + group_index: i32, + }, + + #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] + DatatypeCannotOrder { data_type: String }, + + #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] + ScalarSubqueryTooManyRows, + + #[error("ArrowError: {0}.")] + Arrow(Arc), + + #[error("InternalError: {0}.")] + Internal(String), +} + +impl SparkError { + /// Serialize this error to JSON format for JNI transfer + pub fn to_json(&self) -> String { + let error_class = self.error_class().unwrap_or(""); + + // Create a JSON structure with errorType, errorClass, and params + match serde_json::to_string(&serde_json::json!({ + "errorType": self.error_type_name(), + "errorClass": error_class, + "params": self.params_as_json(), + })) { + Ok(json) => json, + Err(e) => { + // Fallback if serialization fails + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + } + } + } + + /// Get the error type name for JSON serialization + fn error_type_name(&self) -> &'static str { + match self { + SparkError::CastInvalidValue { .. } => "CastInvalidValue", + SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", + SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", + SparkError::CastOverFlow { .. } => "CastOverFlow", + SparkError::CannotParseDecimal => "CannotParseDecimal", + SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", + SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", + SparkError::DecimalSumOverflow => "DecimalSumOverflow", + SparkError::DivideByZero => "DivideByZero", + SparkError::RemainderByZero => "RemainderByZero", + SparkError::IntervalDividedByZero => "IntervalDividedByZero", + SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", + SparkError::IntervalArithmeticOverflowWithSuggestion { .. 
} => { + "IntervalArithmeticOverflowWithSuggestion" + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + "IntervalArithmeticOverflowWithoutSuggestion" + } + SparkError::DatetimeOverflow => "DatetimeOverflow", + SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", + SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", + SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", + SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", + SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", + SparkError::NullMapKey => "NullMapKey", + SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", + SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", + SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", + SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", + SparkError::ValueIsNull { .. } => "ValueIsNull", + SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", + SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", + SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", + SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", + SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", + SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", + SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", + SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", + SparkError::Arrow(_) => "Arrow", + SparkError::Internal(_) => "Internal", + } + } + + /// Extract parameters as JSON value + fn params_as_json(&self) -> serde_json::Value { + match self { + SparkError::CastInvalidValue { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::NumericValueOutOfRange { + value, + precision, + scale, + } => { + serde_json::json!({ + "value": value, + "precision": precision, + "scale": scale, + }) + } + SparkError::NumericOutOfRange { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::CastOverFlow { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::ArithmeticOverflow { from_type } => { + serde_json::json!({ + "fromType": from_type, + }) + } + SparkError::BinaryArithmeticOverflow { + value1, + symbol, + value2, + function_name, + } => { + serde_json::json!({ + "value1": value1, + "symbol": symbol, + "value2": value2, + "functionName": function_name, + }) + } + SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { + serde_json::json!({ + "functionName": function_name, + }) + } + SparkError::InvalidArrayIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidElementAtIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidBitmapPosition { + bit_position, + bitmap_num_bytes, + bitmap_num_bits, + } => { + serde_json::json!({ + "bitPosition": bit_position, + "bitmapNumBytes": bitmap_num_bytes, + "bitmapNumBits": bitmap_num_bits, + }) + } + SparkError::DuplicatedMapKey { key } => { + serde_json::json!({ + "key": 
key, + }) + } + SparkError::ExceedMapSizeLimit { size, max_size } => { + serde_json::json!({ + "size": size, + "maxSize": max_size, + }) + } + SparkError::CollectionSizeLimitExceeded { + num_elements, + max_elements, + } => { + serde_json::json!({ + "numElements": num_elements, + "maxElements": max_elements, + }) + } + SparkError::NotNullAssertViolation { field_name } => { + serde_json::json!({ + "fieldName": field_name, + }) + } + SparkError::ValueIsNull { + field_name, + row_index, + } => { + serde_json::json!({ + "fieldName": field_name, + "rowIndex": row_index, + }) + } + SparkError::CannotParseTimestamp { + message, + suggested_func, + } => { + serde_json::json!({ + "message": message, + "suggestedFunc": suggested_func, + }) + } + SparkError::InvalidFractionOfSecond { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::InvalidUtf8String { hex_string } => { + serde_json::json!({ + "hexString": hex_string, + }) + } + SparkError::UnexpectedPositiveValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::UnexpectedNegativeValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::InvalidRegexGroupIndex { + function_name, + group_count, + group_index, + } => { + serde_json::json!({ + "functionName": function_name, + "groupCount": group_count, + "groupIndex": group_index, + }) + } + SparkError::DatatypeCannotOrder { data_type } => { + serde_json::json!({ + "dataType": data_type, + }) + } + SparkError::Arrow(e) => { + serde_json::json!({ + "message": e.to_string(), + }) + } + SparkError::Internal(msg) => { + serde_json::json!({ + "message": msg, + }) + } + // Simple errors with no parameters + _ => serde_json::json!({}), + } + } + + /// Returns the appropriate Spark exception class for this error + pub fn exception_class(&self) -> &'static 
str { + match self { + // ArithmeticException + SparkError::DivideByZero + | SparkError::RemainderByZero + | SparkError::IntervalDividedByZero + | SparkError::NumericValueOutOfRange { .. } + | SparkError::NumericOutOfRange { .. } // Comet-specific extension + | SparkError::ArithmeticOverflow { .. } + | SparkError::IntegralDivideOverflow + | SparkError::DecimalSumOverflow + | SparkError::BinaryArithmeticOverflow { .. } + | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } + | SparkError::IntervalArithmeticOverflowWithoutSuggestion + | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", + + // CastOverflow gets special handling with CastOverflowException + SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", + + // NumberFormatException (for cast invalid input errors) + SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", + + // ArrayIndexOutOfBoundsException + SparkError::InvalidArrayIndex { .. } + | SparkError::InvalidElementAtIndex { .. } + | SparkError::InvalidBitmapPosition { .. } + | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", + + // RuntimeException + SparkError::CannotParseDecimal + | SparkError::DuplicatedMapKey { .. } + | SparkError::NullMapKey + | SparkError::MapKeyValueDiffSizes + | SparkError::ExceedMapSizeLimit { .. } + | SparkError::CollectionSizeLimitExceeded { .. } + | SparkError::NotNullAssertViolation { .. } + | SparkError::ValueIsNull { .. } // Comet-specific extension + | SparkError::UnexpectedPositiveValue { .. } + | SparkError::UnexpectedNegativeValue { .. } + | SparkError::InvalidRegexGroupIndex { .. } + | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", + + // DateTimeException + SparkError::CannotParseTimestamp { .. } + | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", + + // IllegalArgumentException + SparkError::DatatypeCannotOrder { .. } + | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", + + // Generic errors + SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", + } + } + + /// Returns the Spark error class code for this error + pub fn error_class(&self) -> Option<&'static str> { + match self { + // Cast errors + SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), + SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), + SparkError::NumericValueOutOfRange { .. } => { + Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") + } + SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), + SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), + + // Arithmetic errors + SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), + SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), + SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), + SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), + SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), + + // Array index errors + SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), + SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), + SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), + + // Map/Collection errors + SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), + SparkError::NullMapKey => Some("NULL_MAP_KEY"), + SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), + SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), + SparkError::CollectionSizeLimitExceeded { .. } => { + Some("COLLECTION_SIZE_LIMIT_EXCEEDED") + } + + // Null validation errors + SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), + SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), + + // DateTime errors + SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), + SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), + + // String/UTF8 errors + SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), + + // Function parameter errors + SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), + SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), + + // Regex errors + SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), + + // Unsupported operation errors + SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), + + // Subquery errors + SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), + + // Generic errors (no error class) + SparkError::Arrow(_) | SparkError::Internal(_) => None, + } + } +} + +/// Convert decimal overflow to SparkError::NumericValueOutOfRange. +/// +/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage. 
+/// +/// # Arguments +/// * `value` - The i128 decimal value that overflowed +/// * `precision` - The target precision +/// * `scale` - The scale of the decimal +/// +/// # Returns +/// SparkError::NumericValueOutOfRange with the value, precision, and scale +pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { + SparkError::NumericValueOutOfRange { + value: value.to_string(), + precision, + scale, + } +} + +pub type SparkResult = Result; + +/// Wrapper that adds QueryContext to SparkError +/// +/// This allows attaching SQL context information (query text, line/position, object name) to errors +#[derive(Debug, Clone)] +pub struct SparkErrorWithContext { + /// The underlying SparkError + pub error: SparkError, + /// Optional QueryContext for SQL location information + pub context: Option>, +} + +impl SparkErrorWithContext { + /// Create a SparkErrorWithContext without context + pub fn new(error: SparkError) -> Self { + Self { + error, + context: None, + } + } + + /// Create a SparkErrorWithContext with QueryContext + pub fn with_context(error: SparkError, context: Arc) -> Self { + Self { + error, + context: Some(context), + } + } + + /// Serialize to JSON including optional context field + /// + /// JSON structure: + /// ```json + /// { + /// "errorType": "DivideByZero", + /// "errorClass": "DIVIDE_BY_ZERO", + /// "params": {}, + /// "context": { + /// "sqlText": "SELECT a/b FROM t", + /// "startIndex": 7, + /// "stopIndex": 9, + /// "line": 1, + /// "startPosition": 7 + /// }, + /// "summary": "== SQL (line 1, position 8) ==\n..." 
+ /// } + /// ``` + pub fn to_json(&self) -> String { + let mut json_obj = serde_json::json!({ + "errorType": self.error.error_type_name(), + "errorClass": self.error.error_class().unwrap_or(""), + "params": self.error.params_as_json(), + }); + + if let Some(ctx) = &self.context { + // Serialize context fields + json_obj["context"] = serde_json::json!({ + "sqlText": ctx.sql_text.as_str(), + "startIndex": ctx.start_index, + "stopIndex": ctx.stop_index, + "objectType": ctx.object_type, + "objectName": ctx.object_name, + "line": ctx.line, + "startPosition": ctx.start_position, + }); + + // Add formatted summary + json_obj["summary"] = serde_json::json!(ctx.format_summary()); + } + + serde_json::to_string(&json_obj).unwrap_or_else(|e| { + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + }) + } +} + +impl std::fmt::Display for SparkErrorWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error)?; + if let Some(ctx) = &self.context { + write!(f, "\n{}", ctx.format_summary())?; + } + Ok(()) + } +} + +impl std::error::Error for SparkErrorWithContext {} + +impl From for SparkErrorWithContext { + fn from(error: SparkError) -> Self { + SparkErrorWithContext::new(error) + } +} + +impl From for DataFusionError { + fn from(value: SparkErrorWithContext) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +impl From for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(Arc::new(value)) + } +} + +impl From for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_divide_by_zero_json() { + let error = SparkError::DivideByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"DivideByZero\"")); + assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); + + // Verify it's valid JSON + let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DivideByZero"); + assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); + } + + #[test] + fn test_remainder_by_zero_json() { + let error = SparkError::RemainderByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"RemainderByZero\"")); + assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); + } + + #[test] + fn test_binary_overflow_json() { + let error = SparkError::BinaryArithmeticOverflow { + value1: "32767".to_string(), + symbol: "+".to_string(), + value2: "1".to_string(), + function_name: "try_add".to_string(), + }; + let json = error.to_json(); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); + assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); + assert_eq!(parsed["params"]["value1"], "32767"); + assert_eq!(parsed["params"]["symbol"], "+"); + assert_eq!(parsed["params"]["value2"], "1"); + assert_eq!(parsed["params"]["functionName"], "try_add"); + } + + #[test] + fn test_invalid_array_index_json() { + let error = SparkError::InvalidArrayIndex { + index_value: 10, + array_size: 3, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "InvalidArrayIndex"); + assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); + assert_eq!(parsed["params"]["indexValue"], 10); + assert_eq!(parsed["params"]["arraySize"], 3); + } + + #[test] + fn test_numeric_value_out_of_range_json() { + let error = SparkError::NumericValueOutOfRange { + value: "999.99".to_string(), + precision: 5, + scale: 2, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); + assert_eq!( + parsed["errorClass"], + "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" + ); + 
assert_eq!(parsed["params"]["value"], "999.99"); + assert_eq!(parsed["params"]["precision"], 5); + assert_eq!(parsed["params"]["scale"], 2); + } + + #[test] + fn test_cast_invalid_value_json() { + let error = SparkError::CastInvalidValue { + value: "abc".to_string(), + from_type: "STRING".to_string(), + to_type: "INT".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "CastInvalidValue"); + assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); + assert_eq!(parsed["params"]["value"], "abc"); + assert_eq!(parsed["params"]["fromType"], "STRING"); + assert_eq!(parsed["params"]["toType"], "INT"); + } + + #[test] + fn test_duplicated_map_key_json() { + let error = SparkError::DuplicatedMapKey { + key: "duplicate_key".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DuplicatedMapKey"); + assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); + assert_eq!(parsed["params"]["key"], "duplicate_key"); + } + + #[test] + fn test_null_map_key_json() { + let error = SparkError::NullMapKey; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NullMapKey"); + assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); + // Params should be an empty object + assert_eq!(parsed["params"], serde_json::json!({})); + } + + #[test] + fn test_error_class_mapping() { + // Test that error_class() returns the correct error class + assert_eq!( + SparkError::DivideByZero.error_class(), + Some("DIVIDE_BY_ZERO") + ); + assert_eq!( + SparkError::RemainderByZero.error_class(), + Some("REMAINDER_BY_ZERO") + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .error_class(), + Some("INVALID_ARRAY_INDEX") + ); + assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); + } + + #[test] + fn test_exception_class_mapping() { + // Test that exception_class() returns the correct Java exception class + assert_eq!( + SparkError::DivideByZero.exception_class(), + "org/apache/spark/SparkArithmeticException" + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .exception_class(), + "org/apache/spark/SparkArrayIndexOutOfBoundsException" + ); + assert_eq!( + SparkError::NullMapKey.exception_class(), + "org/apache/spark/SparkRuntimeException" + ); + } +} diff --git a/native/core/testdata/backtrace.txt b/native/jvm-bridge/testdata/backtrace.txt similarity index 100% rename from native/core/testdata/backtrace.txt rename to native/jvm-bridge/testdata/backtrace.txt diff --git a/native/core/testdata/stacktrace.txt b/native/jvm-bridge/testdata/stacktrace.txt similarity index 100% rename from native/core/testdata/stacktrace.txt rename to native/jvm-bridge/testdata/stacktrace.txt diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 9f08e480f2..651b11feb1 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -30,6 +30,7 @@ edition = { workspace = true } arrow = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true } +datafusion-comet-jvm-bridge = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index ae3b5c0eda..c4c97c1b41 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -15,855 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::error::ArrowError; -use datafusion::common::DataFusionError; -use std::sync::Arc; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum SparkError { - // This list was generated from the Spark code. 
Many of the exceptions are not yet used by Comet - #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - because it is malformed. Correct the value as per the syntax, or change its target type. \ - Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastInvalidValue { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] - NumericValueOutOfRange { - value: String, - precision: u8, - scale: i8, - }, - - #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] - NumericOutOfRange { value: String }, - - #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastOverFlow { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] - CannotParseDecimal, - - #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - ArithmeticOverflow { from_type: String }, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntegralDivideOverflow, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DecimalSumOverflow, - - #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DivideByZero, - - #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - RemainderByZero, - - #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] - IntervalDividedByZero, - - #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - BinaryArithmeticOverflow { - value1: String, - symbol: String, - value2: String, - function_name: String, - }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithSuggestion { function_name: String }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithoutSuggestion, - - #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] - DatetimeOverflow, - - #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidArrayIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidElementAtIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] - InvalidBitmapPosition { - bit_position: i64, - bitmap_num_bytes: i64, - bitmap_num_bits: i64, - }, - - #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] - InvalidIndexOfZero, - - #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] - DuplicatedMapKey { key: String }, - - #[error("[NULL_MAP_KEY] Cannot use null as map key.")] - NullMapKey, - - #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] - MapKeyValueDiffSizes, - - #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] - ExceedMapSizeLimit { size: i32, max_size: i32 }, - - #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] - CollectionSizeLimitExceeded { - num_elements: i64, - max_elements: i64, - }, - - #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] - NotNullAssertViolation { field_name: String }, - - #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] - ValueIsNull { field_name: String, row_index: i32 }, - - #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. 
Try using `{suggested_func}` instead.")] - CannotParseTimestamp { - message: String, - suggested_func: String, - }, - - #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidFractionOfSecond { value: f64 }, - - #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] - InvalidUtf8String { hex_string: String }, - - #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] - UnexpectedPositiveValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] - UnexpectedNegativeValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. 
Group count is {group_count}.")] - InvalidRegexGroupIndex { - function_name: String, - group_count: i32, - group_index: i32, - }, - - #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] - DatatypeCannotOrder { data_type: String }, - - #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] - ScalarSubqueryTooManyRows, - - #[error("ArrowError: {0}.")] - Arrow(Arc), - - #[error("InternalError: {0}.")] - Internal(String), -} - -impl SparkError { - /// Serialize this error to JSON format for JNI transfer - pub fn to_json(&self) -> String { - let error_class = self.error_class().unwrap_or(""); - - // Create a JSON structure with errorType, errorClass, and params - match serde_json::to_string(&serde_json::json!({ - "errorType": self.error_type_name(), - "errorClass": error_class, - "params": self.params_as_json(), - })) { - Ok(json) => json, - Err(e) => { - // Fallback if serialization fails - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - } - } - } - - /// Get the error type name for JSON serialization - fn error_type_name(&self) -> &'static str { - match self { - SparkError::CastInvalidValue { .. } => "CastInvalidValue", - SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", - SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", - SparkError::CastOverFlow { .. } => "CastOverFlow", - SparkError::CannotParseDecimal => "CannotParseDecimal", - SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", - SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", - SparkError::DecimalSumOverflow => "DecimalSumOverflow", - SparkError::DivideByZero => "DivideByZero", - SparkError::RemainderByZero => "RemainderByZero", - SparkError::IntervalDividedByZero => "IntervalDividedByZero", - SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", - SparkError::IntervalArithmeticOverflowWithSuggestion { .. 
} => { - "IntervalArithmeticOverflowWithSuggestion" - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - "IntervalArithmeticOverflowWithoutSuggestion" - } - SparkError::DatetimeOverflow => "DatetimeOverflow", - SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", - SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", - SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", - SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", - SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", - SparkError::NullMapKey => "NullMapKey", - SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", - SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", - SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", - SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", - SparkError::ValueIsNull { .. } => "ValueIsNull", - SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", - SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", - SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", - SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", - SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", - SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", - SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", - SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", - SparkError::Arrow(_) => "Arrow", - SparkError::Internal(_) => "Internal", - } - } - - /// Extract parameters as JSON value - fn params_as_json(&self) -> serde_json::Value { - match self { - SparkError::CastInvalidValue { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::NumericValueOutOfRange { - value, - precision, - scale, - } => { - serde_json::json!({ - "value": value, - "precision": precision, - "scale": scale, - }) - } - SparkError::NumericOutOfRange { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::CastOverFlow { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::ArithmeticOverflow { from_type } => { - serde_json::json!({ - "fromType": from_type, - }) - } - SparkError::BinaryArithmeticOverflow { - value1, - symbol, - value2, - function_name, - } => { - serde_json::json!({ - "value1": value1, - "symbol": symbol, - "value2": value2, - "functionName": function_name, - }) - } - SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { - serde_json::json!({ - "functionName": function_name, - }) - } - SparkError::InvalidArrayIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidElementAtIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidBitmapPosition { - bit_position, - bitmap_num_bytes, - bitmap_num_bits, - } => { - serde_json::json!({ - "bitPosition": bit_position, - "bitmapNumBytes": bitmap_num_bytes, - "bitmapNumBits": bitmap_num_bits, - }) - } - SparkError::DuplicatedMapKey { key } => { - serde_json::json!({ - "key": 
key, - }) - } - SparkError::ExceedMapSizeLimit { size, max_size } => { - serde_json::json!({ - "size": size, - "maxSize": max_size, - }) - } - SparkError::CollectionSizeLimitExceeded { - num_elements, - max_elements, - } => { - serde_json::json!({ - "numElements": num_elements, - "maxElements": max_elements, - }) - } - SparkError::NotNullAssertViolation { field_name } => { - serde_json::json!({ - "fieldName": field_name, - }) - } - SparkError::ValueIsNull { - field_name, - row_index, - } => { - serde_json::json!({ - "fieldName": field_name, - "rowIndex": row_index, - }) - } - SparkError::CannotParseTimestamp { - message, - suggested_func, - } => { - serde_json::json!({ - "message": message, - "suggestedFunc": suggested_func, - }) - } - SparkError::InvalidFractionOfSecond { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::InvalidUtf8String { hex_string } => { - serde_json::json!({ - "hexString": hex_string, - }) - } - SparkError::UnexpectedPositiveValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::UnexpectedNegativeValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::InvalidRegexGroupIndex { - function_name, - group_count, - group_index, - } => { - serde_json::json!({ - "functionName": function_name, - "groupCount": group_count, - "groupIndex": group_index, - }) - } - SparkError::DatatypeCannotOrder { data_type } => { - serde_json::json!({ - "dataType": data_type, - }) - } - SparkError::Arrow(e) => { - serde_json::json!({ - "message": e.to_string(), - }) - } - SparkError::Internal(msg) => { - serde_json::json!({ - "message": msg, - }) - } - // Simple errors with no parameters - _ => serde_json::json!({}), - } - } - - /// Returns the appropriate Spark exception class for this error - pub fn exception_class(&self) -> &'static 
str { - match self { - // ArithmeticException - SparkError::DivideByZero - | SparkError::RemainderByZero - | SparkError::IntervalDividedByZero - | SparkError::NumericValueOutOfRange { .. } - | SparkError::NumericOutOfRange { .. } // Comet-specific extension - | SparkError::ArithmeticOverflow { .. } - | SparkError::IntegralDivideOverflow - | SparkError::DecimalSumOverflow - | SparkError::BinaryArithmeticOverflow { .. } - | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } - | SparkError::IntervalArithmeticOverflowWithoutSuggestion - | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", - - // CastOverflow gets special handling with CastOverflowException - SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", - - // NumberFormatException (for cast invalid input errors) - SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", - - // ArrayIndexOutOfBoundsException - SparkError::InvalidArrayIndex { .. } - | SparkError::InvalidElementAtIndex { .. } - | SparkError::InvalidBitmapPosition { .. } - | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", - - // RuntimeException - SparkError::CannotParseDecimal - | SparkError::DuplicatedMapKey { .. } - | SparkError::NullMapKey - | SparkError::MapKeyValueDiffSizes - | SparkError::ExceedMapSizeLimit { .. } - | SparkError::CollectionSizeLimitExceeded { .. } - | SparkError::NotNullAssertViolation { .. } - | SparkError::ValueIsNull { .. } // Comet-specific extension - | SparkError::UnexpectedPositiveValue { .. } - | SparkError::UnexpectedNegativeValue { .. } - | SparkError::InvalidRegexGroupIndex { .. } - | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", - - // DateTimeException - SparkError::CannotParseTimestamp { .. } - | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", - - // IllegalArgumentException - SparkError::DatatypeCannotOrder { .. } - | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", - - // Generic errors - SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", - } - } - - /// Returns the Spark error class code for this error - pub fn error_class(&self) -> Option<&'static str> { - match self { - // Cast errors - SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), - SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), - SparkError::NumericValueOutOfRange { .. } => { - Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") - } - SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), - SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), - - // Arithmetic errors - SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), - SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), - SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), - SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), - SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), - - // Array index errors - SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), - SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), - SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), - SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), - - // Map/Collection errors - SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), - SparkError::NullMapKey => Some("NULL_MAP_KEY"), - SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), - SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), - SparkError::CollectionSizeLimitExceeded { .. } => { - Some("COLLECTION_SIZE_LIMIT_EXCEEDED") - } - - // Null validation errors - SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), - SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), - - // DateTime errors - SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), - SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), - - // String/UTF8 errors - SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), - - // Function parameter errors - SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), - SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), - - // Regex errors - SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), - - // Unsupported operation errors - SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), - - // Subquery errors - SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), - - // Generic errors (no error class) - SparkError::Arrow(_) | SparkError::Internal(_) => None, - } - } -} - -/// Convert decimal overflow to SparkError::NumericValueOutOfRange. -/// -/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage. 
-/// -/// # Arguments -/// * `value` - The i128 decimal value that overflowed -/// * `precision` - The target precision -/// * `scale` - The scale of the decimal -/// -/// # Returns -/// SparkError::NumericValueOutOfRange with the value, precision, and scale -pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { - SparkError::NumericValueOutOfRange { - value: value.to_string(), - precision, - scale, - } -} - -pub type SparkResult = Result; - -/// Wrapper that adds QueryContext to SparkError -/// -/// This allows attaching SQL context information (query text, line/position, object name) to errors -#[derive(Debug, Clone)] -pub struct SparkErrorWithContext { - /// The underlying SparkError - pub error: SparkError, - /// Optional QueryContext for SQL location information - pub context: Option>, -} - -impl SparkErrorWithContext { - /// Create a SparkErrorWithContext without context - pub fn new(error: SparkError) -> Self { - Self { - error, - context: None, - } - } - - /// Create a SparkErrorWithContext with QueryContext - pub fn with_context(error: SparkError, context: Arc) -> Self { - Self { - error, - context: Some(context), - } - } - - /// Serialize to JSON including optional context field - /// - /// JSON structure: - /// ```json - /// { - /// "errorType": "DivideByZero", - /// "errorClass": "DIVIDE_BY_ZERO", - /// "params": {}, - /// "context": { - /// "sqlText": "SELECT a/b FROM t", - /// "startIndex": 7, - /// "stopIndex": 9, - /// "line": 1, - /// "startPosition": 7 - /// }, - /// "summary": "== SQL (line 1, position 8) ==\n..." 
- /// } - /// ``` - pub fn to_json(&self) -> String { - let mut json_obj = serde_json::json!({ - "errorType": self.error.error_type_name(), - "errorClass": self.error.error_class().unwrap_or(""), - "params": self.error.params_as_json(), - }); - - if let Some(ctx) = &self.context { - // Serialize context fields - json_obj["context"] = serde_json::json!({ - "sqlText": ctx.sql_text.as_str(), - "startIndex": ctx.start_index, - "stopIndex": ctx.stop_index, - "objectType": ctx.object_type, - "objectName": ctx.object_name, - "line": ctx.line, - "startPosition": ctx.start_position, - }); - - // Add formatted summary - json_obj["summary"] = serde_json::json!(ctx.format_summary()); - } - - serde_json::to_string(&json_obj).unwrap_or_else(|e| { - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - }) - } -} - -impl std::fmt::Display for SparkErrorWithContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error)?; - if let Some(ctx) = &self.context { - write!(f, "\n{}", ctx.format_summary())?; - } - Ok(()) - } -} - -impl std::error::Error for SparkErrorWithContext {} - -impl From for SparkErrorWithContext { - fn from(error: SparkError) -> Self { - SparkErrorWithContext::new(error) - } -} - -impl From for DataFusionError { - fn from(value: SparkErrorWithContext) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -impl From for SparkError { - fn from(value: ArrowError) -> Self { - SparkError::Arrow(Arc::new(value)) - } -} - -impl From for DataFusionError { - fn from(value: SparkError) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_divide_by_zero_json() { - let error = SparkError::DivideByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"DivideByZero\"")); - assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); - - // Verify it's valid JSON - let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DivideByZero"); - assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); - } - - #[test] - fn test_remainder_by_zero_json() { - let error = SparkError::RemainderByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"RemainderByZero\"")); - assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); - } - - #[test] - fn test_binary_overflow_json() { - let error = SparkError::BinaryArithmeticOverflow { - value1: "32767".to_string(), - symbol: "+".to_string(), - value2: "1".to_string(), - function_name: "try_add".to_string(), - }; - let json = error.to_json(); - - // Verify it's valid JSON - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); - assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); - assert_eq!(parsed["params"]["value1"], "32767"); - assert_eq!(parsed["params"]["symbol"], "+"); - assert_eq!(parsed["params"]["value2"], "1"); - assert_eq!(parsed["params"]["functionName"], "try_add"); - } - - #[test] - fn test_invalid_array_index_json() { - let error = SparkError::InvalidArrayIndex { - index_value: 10, - array_size: 3, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "InvalidArrayIndex"); - assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); - assert_eq!(parsed["params"]["indexValue"], 10); - assert_eq!(parsed["params"]["arraySize"], 3); - } - - #[test] - fn test_numeric_value_out_of_range_json() { - let error = SparkError::NumericValueOutOfRange { - value: "999.99".to_string(), - precision: 5, - scale: 2, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); - assert_eq!( - parsed["errorClass"], - "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" - ); - 
assert_eq!(parsed["params"]["value"], "999.99"); - assert_eq!(parsed["params"]["precision"], 5); - assert_eq!(parsed["params"]["scale"], 2); - } - - #[test] - fn test_cast_invalid_value_json() { - let error = SparkError::CastInvalidValue { - value: "abc".to_string(), - from_type: "STRING".to_string(), - to_type: "INT".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "CastInvalidValue"); - assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); - assert_eq!(parsed["params"]["value"], "abc"); - assert_eq!(parsed["params"]["fromType"], "STRING"); - assert_eq!(parsed["params"]["toType"], "INT"); - } - - #[test] - fn test_duplicated_map_key_json() { - let error = SparkError::DuplicatedMapKey { - key: "duplicate_key".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DuplicatedMapKey"); - assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); - assert_eq!(parsed["params"]["key"], "duplicate_key"); - } - - #[test] - fn test_null_map_key_json() { - let error = SparkError::NullMapKey; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NullMapKey"); - assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); - // Params should be an empty object - assert_eq!(parsed["params"], serde_json::json!({})); - } - - #[test] - fn test_error_class_mapping() { - // Test that error_class() returns the correct error class - assert_eq!( - SparkError::DivideByZero.error_class(), - Some("DIVIDE_BY_ZERO") - ); - assert_eq!( - SparkError::RemainderByZero.error_class(), - Some("REMAINDER_BY_ZERO") - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .error_class(), - Some("INVALID_ARRAY_INDEX") - ); - assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); - } - - #[test] - fn test_exception_class_mapping() { - // Test that exception_class() returns the correct Java exception class - assert_eq!( - SparkError::DivideByZero.exception_class(), - "org/apache/spark/SparkArithmeticException" - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .exception_class(), - "org/apache/spark/SparkArrayIndexOutOfBoundsException" - ); - assert_eq!( - SparkError::NullMapKey.exception_class(), - "org/apache/spark/SparkRuntimeException" - ); - } -} +// Re-export SparkError types from jvm-bridge crate +pub use datafusion_comet_jvm_bridge::spark_error::{ + decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult, +}; diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs index e6591135e0..e7ae1e01a9 100644 --- a/native/spark-expr/src/query_context.rs +++ b/native/spark-expr/src/query_context.rs @@ -15,388 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Query execution context for error reporting -//! -//! This module provides QueryContext which mirrors Spark's SQLQueryContext -//! for providing SQL text, line/position information, and error location -//! pointers in exception messages. - -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -/// Based on Spark's SQLQueryContext for error reporting. -/// -/// Contains information about where an error occurred in a SQL query, -/// including the full SQL text, line/column positions, and object context. 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct QueryContext { - /// Full SQL query text - #[serde(rename = "sqlText")] - pub sql_text: Arc, - - /// Start offset in SQL text (0-based, character index) - #[serde(rename = "startIndex")] - pub start_index: i32, - - /// Stop offset in SQL text (0-based, character index, inclusive) - #[serde(rename = "stopIndex")] - pub stop_index: i32, - - /// Object type (e.g., "VIEW", "Project", "Filter") - #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] - pub object_type: Option, - - /// Object name (e.g., view name, column name) - #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] - pub object_name: Option, - - /// Line number in SQL query (1-based) - pub line: i32, - - /// Column position within the line (0-based) - #[serde(rename = "startPosition")] - pub start_position: i32, -} - -impl QueryContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - sql_text: String, - start_index: i32, - stop_index: i32, - object_type: Option, - object_name: Option, - line: i32, - start_position: i32, - ) -> Self { - Self { - sql_text: Arc::new(sql_text), - start_index, - stop_index, - object_type, - object_name, - line, - start_position, - } - } - - /// Convert a character index to a byte offset in the SQL text. - /// Returns None if the character index is out of range. - fn char_index_to_byte_offset(&self, char_index: usize) -> Option { - self.sql_text - .char_indices() - .nth(char_index) - .map(|(byte_offset, _)| byte_offset) - } - - /// Generate a summary string showing SQL fragment with error location. 
- /// (From SQLQueryContext.summary) - /// - /// Format example: - /// ```text - /// == SQL of VIEW v1 (line 1, position 8) == - /// SELECT a/b FROM t - /// ^^^ - /// ``` - pub fn format_summary(&self) -> String { - let start_char = self.start_index.max(0) as usize; - // stop_index is inclusive; fragment covers [start, stop] - let stop_char = (self.stop_index + 1).max(0) as usize; - - let fragment = match ( - self.char_index_to_byte_offset(start_char), - // stop_char may equal sql_text.chars().count() (one past the end) - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], - _ => "", - }; - - // Build the header line - let mut summary = String::from("== SQL"); - - if let Some(obj_type) = &self.object_type { - if !obj_type.is_empty() { - summary.push_str(" of "); - summary.push_str(obj_type); - - if let Some(obj_name) = &self.object_name { - if !obj_name.is_empty() { - summary.push(' '); - summary.push_str(obj_name); - } - } - } - } - - summary.push_str(&format!( - " (line {}, position {}) ==\n", - self.line, - self.start_position + 1 // Convert 0-based to 1-based for display - )); - - // Add the SQL text with fragment highlighted - summary.push_str(&self.sql_text); - summary.push('\n'); - - // Add caret pointer - let caret_position = self.start_position.max(0) as usize; - summary.push_str(&" ".repeat(caret_position)); - // fragment.chars().count() gives the correct display width for non-ASCII - summary.push_str(&"^".repeat(fragment.chars().count().max(1))); - - summary - } - - /// Returns the SQL fragment that caused the error. 
- pub fn fragment(&self) -> String { - let start_char = self.start_index.max(0) as usize; - let stop_char = (self.stop_index + 1).max(0) as usize; - - match ( - self.char_index_to_byte_offset(start_char), - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), - _ => String::new(), - } - } -} - -use std::collections::HashMap; -use std::sync::RwLock; - -/// Map that stores QueryContext information for expressions during execution. -/// -/// This map is populated during plan deserialization and accessed -/// during error creation to attach SQL context to exceptions. -#[derive(Debug)] -pub struct QueryContextMap { - /// Map from expression ID to QueryContext - contexts: RwLock>>, -} - -impl QueryContextMap { - pub fn new() -> Self { - Self { - contexts: RwLock::new(HashMap::new()), - } - } - - /// Register a QueryContext for an expression ID. - /// - /// If the expression ID already exists, it will be replaced. - /// - /// # Arguments - /// * `expr_id` - Unique expression identifier from protobuf - /// * `context` - QueryContext containing SQL text and position info - pub fn register(&self, expr_id: u64, context: QueryContext) { - let mut contexts = self.contexts.write().unwrap(); - contexts.insert(expr_id, Arc::new(context)); - } - - /// Get the QueryContext for an expression ID. - /// - /// Returns None if no context is registered for this expression. - /// - /// # Arguments - /// * `expr_id` - Expression identifier to look up - pub fn get(&self, expr_id: u64) -> Option> { - let contexts = self.contexts.read().unwrap(); - contexts.get(&expr_id).cloned() - } - - /// Clear all registered contexts. - /// - /// This is typically called after plan execution completes to free memory. 
- pub fn clear(&self) { - let mut contexts = self.contexts.write().unwrap(); - contexts.clear(); - } - - /// Return the number of registered contexts (for debugging/testing) - pub fn len(&self) -> usize { - let contexts = self.contexts.read().unwrap(); - contexts.len() - } - - /// Check if the map is empty - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl Default for QueryContextMap { - fn default() -> Self { - Self::new() - } -} - -/// Create a new session-scoped QueryContextMap. -/// -/// This should be called once per SessionContext during plan creation -/// and passed to expressions that need query context for error reporting. -pub fn create_query_context_map() -> Arc { - Arc::new(QueryContextMap::new()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_query_context_creation() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); - assert_eq!(ctx.start_index, 7); - assert_eq!(ctx.stop_index, 9); - assert_eq!(ctx.object_type, Some("Divide".to_string())); - assert_eq!(ctx.object_name, Some("a/b".to_string())); - assert_eq!(ctx.line, 1); - assert_eq!(ctx.start_position, 7); - } - - #[test] - fn test_query_context_serialization() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - let json = serde_json::to_string(&ctx).unwrap(); - let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); - - assert_eq!(ctx, deserialized); - } - - #[test] - fn test_format_summary() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("VIEW".to_string()), - Some("v1".to_string()), - 1, - 7, - ); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM 
t")); - assert!(summary.contains("^^^")); // Three carets for "a/b" - } - - #[test] - fn test_format_summary_without_object() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM t")); - } - - #[test] - fn test_fragment() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - assert_eq!(ctx.fragment(), "a/b"); - } - - #[test] - fn test_arc_string_sharing() { - let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let ctx2 = ctx1.clone(); - - // Arc should share the same allocation - assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); - } - - #[test] - fn test_json_with_optional_fields() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let json = serde_json::to_string(&ctx).unwrap(); - - // Should not serialize objectType and objectName when None - assert!(!json.contains("objectType")); - assert!(!json.contains("objectName")); - } - - #[test] - fn test_map_register_and_get() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx.clone()); - - let retrieved = map.get(1).unwrap(); - assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); - assert_eq!(retrieved.start_index, 7); - } - - #[test] - fn test_map_get_nonexistent() { - let map = QueryContextMap::new(); - assert!(map.get(999).is_none()); - } - - #[test] - fn test_map_clear() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx); - assert_eq!(map.len(), 1); - - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - } - - // Verify that fragment() and format_summary() correctly handle SQL text that - // 
contains multi-byte characters - - #[test] - fn test_fragment_non_ascii_accented() { - // "é" is a 2-byte UTF-8 sequence (U+00E9). - // SQL: "SELECT café FROM t" - // 0123456789... - // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. - // start_index=7, stop_index=10 should yield "café" - let sql = "SELECT café FROM t".to_string(); - let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); - assert_eq!(ctx.fragment(), "café"); - } -} +// Re-export QueryContext types from jvm-bridge crate +pub use datafusion_comet_jvm_bridge::query_context::{ + create_query_context_map, QueryContext, QueryContextMap, +}; From 0575a7502ab0c247b49708765a6fc1843af5bd8b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 09:37:10 -0600 Subject: [PATCH 02/11] refactor: move decimal_overflow_error back to spark-expr This is expression-level convenience logic, not JNI bridge infrastructure. --- native/jvm-bridge/src/lib.rs | 2 +- native/jvm-bridge/src/spark_error.rs | 19 ------------------- native/spark-expr/src/error.rs | 11 ++++++++++- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/native/jvm-bridge/src/lib.rs b/native/jvm-bridge/src/lib.rs index d6a92af243..2b37d61eec 100644 --- a/native/jvm-bridge/src/lib.rs +++ b/native/jvm-bridge/src/lib.rs @@ -31,7 +31,7 @@ pub mod query_context; pub mod spark_error; pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; -pub use spark_error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; +pub use spark_error::{SparkError, SparkErrorWithContext, SparkResult}; /// Global reference to the Java VM, initialized during native library setup. 
pub static JAVA_VM: OnceCell = OnceCell::new(); diff --git a/native/jvm-bridge/src/spark_error.rs b/native/jvm-bridge/src/spark_error.rs index ae3b5c0eda..fdde808496 100644 --- a/native/jvm-bridge/src/spark_error.rs +++ b/native/jvm-bridge/src/spark_error.rs @@ -565,25 +565,6 @@ impl SparkError { } } -/// Convert decimal overflow to SparkError::NumericValueOutOfRange. -/// -/// Creates the appropriate SparkError when a decimal value exceeds the precision limit for Decimal128 storage. -/// -/// # Arguments -/// * `value` - The i128 decimal value that overflowed -/// * `precision` - The target precision -/// * `scale` - The scale of the decimal -/// -/// # Returns -/// SparkError::NumericValueOutOfRange with the value, precision, and scale -pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { - SparkError::NumericValueOutOfRange { - value: value.to_string(), - precision, - scale, - } -} - pub type SparkResult = Result; /// Wrapper that adds QueryContext to SparkError diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index c4c97c1b41..b162479def 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -17,5 +17,14 @@ // Re-export SparkError types from jvm-bridge crate pub use datafusion_comet_jvm_bridge::spark_error::{ - decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult, + SparkError, SparkErrorWithContext, SparkResult, }; + +/// Convert decimal overflow to SparkError::NumericValueOutOfRange. 
+pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { + SparkError::NumericValueOutOfRange { + value: value.to_string(), + precision, + scale, + } +} From 1ae0cd00311f71421fc8276c496b8d2103709608 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 09:40:33 -0600 Subject: [PATCH 03/11] refactor: tighten visibility in jvm-bridge crate - SparkError::exception_class() -> pub(crate) (only used in errors.rs) - SparkError::error_class() -> private (only used within spark_error.rs) - QueryContext::fragment() -> #[cfg(test)] private (only used in tests) - JVMClasses method ID fields -> private (only used within jvm_bridge mod) --- native/jvm-bridge/src/jvm_bridge/mod.rs | 8 ++++---- native/jvm-bridge/src/query_context.rs | 3 ++- native/jvm-bridge/src/spark_error.rs | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/native/jvm-bridge/src/jvm_bridge/mod.rs b/native/jvm-bridge/src/jvm_bridge/mod.rs index 19389ceeb5..c561197627 100644 --- a/native/jvm-bridge/src/jvm_bridge/mod.rs +++ b/native/jvm-bridge/src/jvm_bridge/mod.rs @@ -188,13 +188,13 @@ pub struct JVMClasses<'a> { /// Cached JClass for "java.lang.Throwable" java_lang_throwable: JClass<'a>, /// Cached method ID for "java.lang.Object#getClass" - pub object_get_class_method: JMethodID, + object_get_class_method: JMethodID, /// Cached method ID for "java.lang.Class#getName" - pub class_get_name_method: JMethodID, + class_get_name_method: JMethodID, /// Cached method ID for "java.lang.Throwable#getMessage" - pub throwable_get_message_method: JMethodID, + throwable_get_message_method: JMethodID, /// Cached method ID for "java.lang.Throwable#getCause" - pub throwable_get_cause_method: JMethodID, + throwable_get_cause_method: JMethodID, /// The CometMetricNode class. Used for updating the metrics. 
pub comet_metric_node: CometMetricNode<'a>, diff --git a/native/jvm-bridge/src/query_context.rs b/native/jvm-bridge/src/query_context.rs index e6591135e0..10ecef6550 100644 --- a/native/jvm-bridge/src/query_context.rs +++ b/native/jvm-bridge/src/query_context.rs @@ -155,7 +155,8 @@ impl QueryContext { } /// Returns the SQL fragment that caused the error. - pub fn fragment(&self) -> String { + #[cfg(test)] + fn fragment(&self) -> String { let start_char = self.start_index.max(0) as usize; let stop_char = (self.stop_index + 1).max(0) as usize; diff --git a/native/jvm-bridge/src/spark_error.rs b/native/jvm-bridge/src/spark_error.rs index fdde808496..aaec9ab530 100644 --- a/native/jvm-bridge/src/spark_error.rs +++ b/native/jvm-bridge/src/spark_error.rs @@ -437,7 +437,7 @@ impl SparkError { } /// Returns the appropriate Spark exception class for this error - pub fn exception_class(&self) -> &'static str { + pub(crate) fn exception_class(&self) -> &'static str { match self { // ArithmeticException SparkError::DivideByZero @@ -493,7 +493,7 @@ impl SparkError { } /// Returns the Spark error class code for this error - pub fn error_class(&self) -> Option<&'static str> { + fn error_class(&self) -> Option<&'static str> { match self { // Cast errors SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), From 9aab4c4bad41e7e640c10070775c0782c85b3fd6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 10:00:37 -0600 Subject: [PATCH 04/11] chore(deps): remove unused Rust dependencies Remove lazy_static, thiserror, regex from core; log from jvm-bridge; serde, thiserror from spark-expr. Update cargo-machete ignore list. 
--- native/Cargo.lock | 5 ----- native/core/Cargo.toml | 5 +---- native/jvm-bridge/Cargo.toml | 1 - native/spark-expr/Cargo.toml | 2 -- 4 files changed, 1 insertion(+), 12 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 0432c07813..3d048d8efe 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1870,13 +1870,11 @@ dependencies = [ "procfs", "prost", "rand 0.10.0", - "regex", "reqwest", "serde_json", "simd-adler32", "snap", "tempfile", - "thiserror 2.0.18", "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", @@ -1909,7 +1907,6 @@ dependencies = [ "datafusion", "jni", "lazy_static", - "log", "once_cell", "parquet", "paste", @@ -1958,9 +1955,7 @@ dependencies = [ "num", "rand 0.10.0", "regex", - "serde", "serde_json", - "thiserror 2.0.18", "tokio", "twox-hash", ] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index e7d3c96ae6..2780eb01d1 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -45,8 +45,6 @@ tokio = { version = "1", features = ["rt-multi-thread"] } async-trait = { workspace = true } log = "0.4" log4rs = "1.4.0" -thiserror = { workspace = true } -lazy_static = "1.4.0" prost = "0.14.3" jni = "0.21" snap = "1.1" @@ -64,7 +62,6 @@ datafusion-physical-expr-adapter = { workspace = true } datafusion-datasource = { workspace = true } datafusion-spark = { workspace = true } once_cell = "1.18.0" -regex = { workspace = true } crc32fast = "1.3.2" simd-adler32 = "0.3.7" datafusion-comet-spark-expr = { workspace = true } @@ -109,7 +106,7 @@ jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"] # exclude optional packages from cargo machete verifications [package.metadata.cargo-machete] -ignored = ["datafusion-comet-objectstore-hdfs", "hdfs-sys"] +ignored = ["hdfs-sys", "paste"] [lib] name = "comet" diff --git a/native/jvm-bridge/Cargo.toml b/native/jvm-bridge/Cargo.toml index 5870d89b17..bc2bd02add 100644 --- a/native/jvm-bridge/Cargo.toml +++ b/native/jvm-bridge/Cargo.toml @@ -40,7 +40,6 @@ serde_json = "1.0" 
lazy_static = "1.4.0" once_cell = "1.18.0" paste = "1.0.14" -log = "0.4" prost = "0.14.3" [dev-dependencies] diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 651b11feb1..179bacc1b9 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -34,9 +34,7 @@ datafusion-comet-jvm-bridge = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } -serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -thiserror = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } From 2c74084072b2e02283903299290cdd039ea11aa4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 16:00:15 -0600 Subject: [PATCH 05/11] refactor: move SparkError and QueryContext back to spark-expr SparkError and QueryContext are Spark expression types with no JNI dependency, but were placed in jvm-bridge during the crate extraction, forcing spark-expr to depend on jvm-bridge just to re-export its own error types. Move them back to spark-expr where they belong and introduce a pluggable ExternalErrorHandler callback in jvm-bridge so that core can register SparkError-specific JNI exception handling without jvm-bridge needing to know about SparkError. 
Dependency graph after this change: - core -> spark-expr (SparkError types) - core -> jvm-bridge (JNI bridge, CometError, JAVA_VM) - spark-expr has no jvm-bridge dependency - jvm-bridge has no spark-expr dependency --- native/Cargo.lock | 5 +- native/core/src/lib.rs | 30 + native/jvm-bridge/Cargo.toml | 4 +- native/jvm-bridge/src/errors.rs | 77 +-- native/jvm-bridge/src/lib.rs | 10 +- native/jvm-bridge/src/query_context.rs | 403 ------------ native/jvm-bridge/src/spark_error.rs | 850 ------------------------- native/spark-expr/Cargo.toml | 3 +- native/spark-expr/src/error.rs | 837 +++++++++++++++++++++++- native/spark-expr/src/lib.rs | 2 +- native/spark-expr/src/query_context.rs | 390 +++++++++++- 11 files changed, 1287 insertions(+), 1324 deletions(-) delete mode 100644 native/jvm-bridge/src/query_context.rs delete mode 100644 native/jvm-bridge/src/spark_error.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 3d048d8efe..065d67f9c9 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1912,8 +1912,6 @@ dependencies = [ "paste", "prost", "regex", - "serde", - "serde_json", "thiserror 2.0.18", ] @@ -1949,13 +1947,14 @@ dependencies = [ "chrono-tz", "criterion", "datafusion", - "datafusion-comet-jvm-bridge", "futures", "hex", "num", "rand 0.10.0", "regex", + "serde", "serde_json", + "thiserror 2.0.18", "tokio", "twox-hash", ] diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 8ee3139ff9..c7473856ba 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -91,6 +91,9 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( // Initialize the error handling to capture panic backtraces errors::init(); + // Register SparkError handler for JNI exception throwing + errors::register_external_error_handler(handle_spark_error); + try_unwrap_or_throw(&e, |mut env| { let path: String = env.get_string(&log_conf_path)?.into(); @@ -119,6 +122,33 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( }) } +/// Handle 
SparkError variants in DataFusionError::External for JNI exception throwing. +/// Returns true if the error was handled. +fn handle_spark_error( + error: &(dyn std::error::Error + Send + Sync + 'static), + env: &mut JNIEnv, +) -> bool { + use datafusion_comet_spark_expr::{SparkError, SparkErrorWithContext}; + + if let Some(spark_error_with_ctx) = error.downcast_ref::() { + let json_message = spark_error_with_ctx.to_json(); + let _ = env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ); + return true; + } + if let Some(spark_error) = error.downcast_ref::() { + let json_message = spark_error.to_json(); + let _ = env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ); + return true; + } + false +} + const LOG_PATTERN: &str = "{d(%y/%m/%d %H:%M:%S)} {l} {f}: {m}{n}"; /// JNI method to check if a specific feature is enabled in the native Rust code. diff --git a/native/jvm-bridge/Cargo.toml b/native/jvm-bridge/Cargo.toml index bc2bd02add..07ddc3f883 100644 --- a/native/jvm-bridge/Cargo.toml +++ b/native/jvm-bridge/Cargo.toml @@ -21,7 +21,7 @@ version = { workspace = true } homepage = "https://datafusion.apache.org/comet" repository = "https://github.com/apache/datafusion-comet" authors = ["Apache DataFusion "] -description = "Apache DataFusion Comet: JVM bridge and error types" +description = "Apache DataFusion Comet: JVM bridge" readme = "README.md" license = "Apache-2.0" edition = "2021" @@ -35,8 +35,6 @@ datafusion = { workspace = true } jni = "0.21" thiserror = { workspace = true } regex = { workspace = true } -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" lazy_static = "1.4.0" once_cell = "1.18.0" paste = "1.0.14" diff --git a/native/jvm-bridge/src/errors.rs b/native/jvm-bridge/src/errors.rs index 6e8c818430..a71ee9105a 100644 --- a/native/jvm-bridge/src/errors.rs +++ b/native/jvm-bridge/src/errors.rs @@ -37,13 +37,28 @@ use std::{ // lifetime checker won't 
let us. use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, jshort}; -use crate::spark_error::{SparkError, SparkErrorWithContext}; use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; +use once_cell::sync::OnceCell; use parquet::errors::ParquetError; use thiserror::Error; +/// Handler for DataFusionError::External errors during JNI exception throwing. +/// Returns true if the error was handled (exception thrown), false to fall through +/// to the default handler. +pub type ExternalErrorHandler = + fn(error: &(dyn std::error::Error + Send + Sync + 'static), env: &mut JNIEnv) -> bool; + +static EXTERNAL_ERROR_HANDLER: OnceCell = OnceCell::new(); + +/// Register a handler for DataFusionError::External errors. +/// This allows the core crate to register SparkError-specific handling +/// without the jvm-bridge crate needing to know about SparkError. +pub fn register_external_error_handler(handler: ExternalErrorHandler) { + let _ = EXTERNAL_ERROR_HANDLER.set(handler); +} + lazy_static! 
{ static ref PANIC_BACKTRACE: Arc>> = Arc::new(Mutex::new(None)); } @@ -90,11 +105,6 @@ pub enum CometError { #[error("Comet Internal Error: {0}")] Internal(String), - /// CometError::Spark is typically used in native code to emulate the same errors - /// that Spark would return - #[error(transparent)] - Spark(SparkError), - #[error(transparent)] Arrow { #[from] @@ -313,10 +323,6 @@ impl jni::errors::ToException for CometError { class: "java/lang/NullPointerException".to_string(), msg: self.to_string(), }, - CometError::Spark(spark_err) => Exception { - class: spark_err.exception_class().to_string(), - msg: spark_err.to_string(), - }, CometError::NumberIntFormat { source: s } => Exception { class: "java/lang/NumberFormatException".to_string(), msg: s.to_string(), @@ -483,35 +489,27 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), - // Handle DataFusion errors containing SparkError - serialize to JSON + // Handle DataFusion errors containing external errors (e.g. SparkError) CometError::DataFusion { msg: _, source: DataFusionError::External(e), } => { - // Try SparkErrorWithContext first (includes context) - if let Some(spark_error_with_ctx) = e.downcast_ref::() { - let json_message = spark_error_with_ctx.to_json(); - env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, - ) - } else if let Some(spark_error) = e.downcast_ref::() { - // Fall back to plain SparkError (no context) - throw_spark_error_as_json(env, spark_error) - } else { - // Not a SparkError, use generic exception - let exception = error.to_exception(); - match backtrace { - Some(backtrace_string) => env.throw_new( - exception.class, - to_stacktrace_string(exception.msg, backtrace_string).unwrap(), - ), - _ => env.throw_new(exception.class, exception.msg), + // Try registered handler first (e.g. 
SparkError handling from core) + if let Some(handler) = EXTERNAL_ERROR_HANDLER.get() { + if handler(e.as_ref(), env) { + return; } } + // Fall through to generic exception + let exception = error.to_exception(); + match backtrace { + Some(backtrace_string) => env.throw_new( + exception.class, + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + _ => env.throw_new(exception.class, exception.msg), + } } - // Handle direct SparkError - serialize to JSON - CometError::Spark(spark_error) => throw_spark_error_as_json(env, spark_error), _ => { let exception = error.to_exception(); match backtrace { @@ -527,21 +525,6 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option jni::errors::Result<()> { - // Serialize error to JSON - let json_message = spark_error.to_json(); - - // Throw CometQueryExecutionException with JSON message - env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, - ) -} - #[derive(Debug, Error)] enum StacktraceError { #[error("Unable to initialize message: {0}")] diff --git a/native/jvm-bridge/src/lib.rs b/native/jvm-bridge/src/lib.rs index 2b37d61eec..19bfb39cfa 100644 --- a/native/jvm-bridge/src/lib.rs +++ b/native/jvm-bridge/src/lib.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! JVM bridge and error types for Apache DataFusion Comet. +//! JVM bridge for Apache DataFusion Comet. //! -//! This crate provides the JNI/JVM interaction layer and common error types -//! used across Comet's native Rust crates. +//! This crate provides the JNI/JVM interaction layer used across Comet's native Rust crates. 
#![allow(clippy::result_large_err)] @@ -27,11 +26,6 @@ use once_cell::sync::OnceCell; pub mod errors; pub mod jvm_bridge; -pub mod query_context; -pub mod spark_error; - -pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; -pub use spark_error::{SparkError, SparkErrorWithContext, SparkResult}; /// Global reference to the Java VM, initialized during native library setup. pub static JAVA_VM: OnceCell = OnceCell::new(); diff --git a/native/jvm-bridge/src/query_context.rs b/native/jvm-bridge/src/query_context.rs deleted file mode 100644 index 10ecef6550..0000000000 --- a/native/jvm-bridge/src/query_context.rs +++ /dev/null @@ -1,403 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Query execution context for error reporting -//! -//! This module provides QueryContext which mirrors Spark's SQLQueryContext -//! for providing SQL text, line/position information, and error location -//! pointers in exception messages. - -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -/// Based on Spark's SQLQueryContext for error reporting. 
-/// -/// Contains information about where an error occurred in a SQL query, -/// including the full SQL text, line/column positions, and object context. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct QueryContext { - /// Full SQL query text - #[serde(rename = "sqlText")] - pub sql_text: Arc, - - /// Start offset in SQL text (0-based, character index) - #[serde(rename = "startIndex")] - pub start_index: i32, - - /// Stop offset in SQL text (0-based, character index, inclusive) - #[serde(rename = "stopIndex")] - pub stop_index: i32, - - /// Object type (e.g., "VIEW", "Project", "Filter") - #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] - pub object_type: Option, - - /// Object name (e.g., view name, column name) - #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] - pub object_name: Option, - - /// Line number in SQL query (1-based) - pub line: i32, - - /// Column position within the line (0-based) - #[serde(rename = "startPosition")] - pub start_position: i32, -} - -impl QueryContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - sql_text: String, - start_index: i32, - stop_index: i32, - object_type: Option, - object_name: Option, - line: i32, - start_position: i32, - ) -> Self { - Self { - sql_text: Arc::new(sql_text), - start_index, - stop_index, - object_type, - object_name, - line, - start_position, - } - } - - /// Convert a character index to a byte offset in the SQL text. - /// Returns None if the character index is out of range. - fn char_index_to_byte_offset(&self, char_index: usize) -> Option { - self.sql_text - .char_indices() - .nth(char_index) - .map(|(byte_offset, _)| byte_offset) - } - - /// Generate a summary string showing SQL fragment with error location. 
- /// (From SQLQueryContext.summary) - /// - /// Format example: - /// ```text - /// == SQL of VIEW v1 (line 1, position 8) == - /// SELECT a/b FROM t - /// ^^^ - /// ``` - pub fn format_summary(&self) -> String { - let start_char = self.start_index.max(0) as usize; - // stop_index is inclusive; fragment covers [start, stop] - let stop_char = (self.stop_index + 1).max(0) as usize; - - let fragment = match ( - self.char_index_to_byte_offset(start_char), - // stop_char may equal sql_text.chars().count() (one past the end) - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], - _ => "", - }; - - // Build the header line - let mut summary = String::from("== SQL"); - - if let Some(obj_type) = &self.object_type { - if !obj_type.is_empty() { - summary.push_str(" of "); - summary.push_str(obj_type); - - if let Some(obj_name) = &self.object_name { - if !obj_name.is_empty() { - summary.push(' '); - summary.push_str(obj_name); - } - } - } - } - - summary.push_str(&format!( - " (line {}, position {}) ==\n", - self.line, - self.start_position + 1 // Convert 0-based to 1-based for display - )); - - // Add the SQL text with fragment highlighted - summary.push_str(&self.sql_text); - summary.push('\n'); - - // Add caret pointer - let caret_position = self.start_position.max(0) as usize; - summary.push_str(&" ".repeat(caret_position)); - // fragment.chars().count() gives the correct display width for non-ASCII - summary.push_str(&"^".repeat(fragment.chars().count().max(1))); - - summary - } - - /// Returns the SQL fragment that caused the error. 
- #[cfg(test)] - fn fragment(&self) -> String { - let start_char = self.start_index.max(0) as usize; - let stop_char = (self.stop_index + 1).max(0) as usize; - - match ( - self.char_index_to_byte_offset(start_char), - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), - _ => String::new(), - } - } -} - -use std::collections::HashMap; -use std::sync::RwLock; - -/// Map that stores QueryContext information for expressions during execution. -/// -/// This map is populated during plan deserialization and accessed -/// during error creation to attach SQL context to exceptions. -#[derive(Debug)] -pub struct QueryContextMap { - /// Map from expression ID to QueryContext - contexts: RwLock>>, -} - -impl QueryContextMap { - pub fn new() -> Self { - Self { - contexts: RwLock::new(HashMap::new()), - } - } - - /// Register a QueryContext for an expression ID. - /// - /// If the expression ID already exists, it will be replaced. - /// - /// # Arguments - /// * `expr_id` - Unique expression identifier from protobuf - /// * `context` - QueryContext containing SQL text and position info - pub fn register(&self, expr_id: u64, context: QueryContext) { - let mut contexts = self.contexts.write().unwrap(); - contexts.insert(expr_id, Arc::new(context)); - } - - /// Get the QueryContext for an expression ID. - /// - /// Returns None if no context is registered for this expression. - /// - /// # Arguments - /// * `expr_id` - Expression identifier to look up - pub fn get(&self, expr_id: u64) -> Option> { - let contexts = self.contexts.read().unwrap(); - contexts.get(&expr_id).cloned() - } - - /// Clear all registered contexts. - /// - /// This is typically called after plan execution completes to free memory. 
- pub fn clear(&self) { - let mut contexts = self.contexts.write().unwrap(); - contexts.clear(); - } - - /// Return the number of registered contexts (for debugging/testing) - pub fn len(&self) -> usize { - let contexts = self.contexts.read().unwrap(); - contexts.len() - } - - /// Check if the map is empty - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl Default for QueryContextMap { - fn default() -> Self { - Self::new() - } -} - -/// Create a new session-scoped QueryContextMap. -/// -/// This should be called once per SessionContext during plan creation -/// and passed to expressions that need query context for error reporting. -pub fn create_query_context_map() -> Arc { - Arc::new(QueryContextMap::new()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_query_context_creation() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); - assert_eq!(ctx.start_index, 7); - assert_eq!(ctx.stop_index, 9); - assert_eq!(ctx.object_type, Some("Divide".to_string())); - assert_eq!(ctx.object_name, Some("a/b".to_string())); - assert_eq!(ctx.line, 1); - assert_eq!(ctx.start_position, 7); - } - - #[test] - fn test_query_context_serialization() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - let json = serde_json::to_string(&ctx).unwrap(); - let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); - - assert_eq!(ctx, deserialized); - } - - #[test] - fn test_format_summary() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("VIEW".to_string()), - Some("v1".to_string()), - 1, - 7, - ); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM 
t")); - assert!(summary.contains("^^^")); // Three carets for "a/b" - } - - #[test] - fn test_format_summary_without_object() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM t")); - } - - #[test] - fn test_fragment() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - assert_eq!(ctx.fragment(), "a/b"); - } - - #[test] - fn test_arc_string_sharing() { - let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let ctx2 = ctx1.clone(); - - // Arc should share the same allocation - assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); - } - - #[test] - fn test_json_with_optional_fields() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let json = serde_json::to_string(&ctx).unwrap(); - - // Should not serialize objectType and objectName when None - assert!(!json.contains("objectType")); - assert!(!json.contains("objectName")); - } - - #[test] - fn test_map_register_and_get() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx.clone()); - - let retrieved = map.get(1).unwrap(); - assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); - assert_eq!(retrieved.start_index, 7); - } - - #[test] - fn test_map_get_nonexistent() { - let map = QueryContextMap::new(); - assert!(map.get(999).is_none()); - } - - #[test] - fn test_map_clear() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx); - assert_eq!(map.len(), 1); - - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - } - - // Verify that fragment() and format_summary() correctly handle SQL text that - // 
contains multi-byte characters - - #[test] - fn test_fragment_non_ascii_accented() { - // "é" is a 2-byte UTF-8 sequence (U+00E9). - // SQL: "SELECT café FROM t" - // 0123456789... - // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. - // start_index=7, stop_index=10 should yield "café" - let sql = "SELECT café FROM t".to_string(); - let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); - assert_eq!(ctx.fragment(), "café"); - } -} diff --git a/native/jvm-bridge/src/spark_error.rs b/native/jvm-bridge/src/spark_error.rs deleted file mode 100644 index aaec9ab530..0000000000 --- a/native/jvm-bridge/src/spark_error.rs +++ /dev/null @@ -1,850 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::error::ArrowError; -use datafusion::common::DataFusionError; -use std::sync::Arc; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum SparkError { - // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet - #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - because it is malformed. Correct the value as per the syntax, or change its target type. 
\ - Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastInvalidValue { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] - NumericValueOutOfRange { - value: String, - precision: u8, - scale: i8, - }, - - #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] - NumericOutOfRange { value: String }, - - #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastOverFlow { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] - CannotParseDecimal, - - #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - ArithmeticOverflow { from_type: String }, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntegralDivideOverflow, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DecimalSumOverflow, - - #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DivideByZero, - - #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - RemainderByZero, - - #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] - IntervalDividedByZero, - - #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - BinaryArithmeticOverflow { - value1: String, - symbol: String, - value2: String, - function_name: String, - }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithSuggestion { function_name: String }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithoutSuggestion, - - #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] - DatetimeOverflow, - - #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidArrayIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidElementAtIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] - InvalidBitmapPosition { - bit_position: i64, - bitmap_num_bytes: i64, - bitmap_num_bits: i64, - }, - - #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] - InvalidIndexOfZero, - - #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] - DuplicatedMapKey { key: String }, - - #[error("[NULL_MAP_KEY] Cannot use null as map key.")] - NullMapKey, - - #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] - MapKeyValueDiffSizes, - - #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] - ExceedMapSizeLimit { size: i32, max_size: i32 }, - - #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] - CollectionSizeLimitExceeded { - num_elements: i64, - max_elements: i64, - }, - - #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] - NotNullAssertViolation { field_name: String }, - - #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] - ValueIsNull { field_name: String, row_index: i32 }, - - #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] - CannotParseTimestamp { - message: String, - suggested_func: String, - }, - - #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidFractionOfSecond { value: f64 }, - - #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] - InvalidUtf8String { hex_string: String }, - - #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] - UnexpectedPositiveValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] - UnexpectedNegativeValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. Group count is {group_count}.")] - InvalidRegexGroupIndex { - function_name: String, - group_count: i32, - group_index: i32, - }, - - #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] - DatatypeCannotOrder { data_type: String }, - - #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] - ScalarSubqueryTooManyRows, - - #[error("ArrowError: {0}.")] - Arrow(Arc), - - #[error("InternalError: {0}.")] - Internal(String), -} - -impl SparkError { - /// Serialize this error to JSON format for JNI transfer - pub fn to_json(&self) -> String { - let error_class = self.error_class().unwrap_or(""); - - // Create a JSON structure with errorType, errorClass, and params - match serde_json::to_string(&serde_json::json!({ - "errorType": self.error_type_name(), - "errorClass": error_class, - "params": self.params_as_json(), - })) { - Ok(json) => json, - Err(e) => { - // Fallback if serialization fails - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - } - } - } - - /// Get the error type name for JSON serialization - fn error_type_name(&self) -> &'static str { - match self { - SparkError::CastInvalidValue { 
.. } => "CastInvalidValue", - SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", - SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", - SparkError::CastOverFlow { .. } => "CastOverFlow", - SparkError::CannotParseDecimal => "CannotParseDecimal", - SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", - SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", - SparkError::DecimalSumOverflow => "DecimalSumOverflow", - SparkError::DivideByZero => "DivideByZero", - SparkError::RemainderByZero => "RemainderByZero", - SparkError::IntervalDividedByZero => "IntervalDividedByZero", - SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - "IntervalArithmeticOverflowWithSuggestion" - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - "IntervalArithmeticOverflowWithoutSuggestion" - } - SparkError::DatetimeOverflow => "DatetimeOverflow", - SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", - SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", - SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", - SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", - SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", - SparkError::NullMapKey => "NullMapKey", - SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", - SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", - SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", - SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", - SparkError::ValueIsNull { .. } => "ValueIsNull", - SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", - SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", - SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", - SparkError::UnexpectedPositiveValue { .. 
} => "UnexpectedPositiveValue", - SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", - SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", - SparkError::DatatypeCannotOrder { .. } => "DatatypeCannotOrder", - SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", - SparkError::Arrow(_) => "Arrow", - SparkError::Internal(_) => "Internal", - } - } - - /// Extract parameters as JSON value - fn params_as_json(&self) -> serde_json::Value { - match self { - SparkError::CastInvalidValue { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::NumericValueOutOfRange { - value, - precision, - scale, - } => { - serde_json::json!({ - "value": value, - "precision": precision, - "scale": scale, - }) - } - SparkError::NumericOutOfRange { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::CastOverFlow { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::ArithmeticOverflow { from_type } => { - serde_json::json!({ - "fromType": from_type, - }) - } - SparkError::BinaryArithmeticOverflow { - value1, - symbol, - value2, - function_name, - } => { - serde_json::json!({ - "value1": value1, - "symbol": symbol, - "value2": value2, - "functionName": function_name, - }) - } - SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { - serde_json::json!({ - "functionName": function_name, - }) - } - SparkError::InvalidArrayIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidElementAtIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidBitmapPosition { - bit_position, - bitmap_num_bytes, - bitmap_num_bits, 
- } => { - serde_json::json!({ - "bitPosition": bit_position, - "bitmapNumBytes": bitmap_num_bytes, - "bitmapNumBits": bitmap_num_bits, - }) - } - SparkError::DuplicatedMapKey { key } => { - serde_json::json!({ - "key": key, - }) - } - SparkError::ExceedMapSizeLimit { size, max_size } => { - serde_json::json!({ - "size": size, - "maxSize": max_size, - }) - } - SparkError::CollectionSizeLimitExceeded { - num_elements, - max_elements, - } => { - serde_json::json!({ - "numElements": num_elements, - "maxElements": max_elements, - }) - } - SparkError::NotNullAssertViolation { field_name } => { - serde_json::json!({ - "fieldName": field_name, - }) - } - SparkError::ValueIsNull { - field_name, - row_index, - } => { - serde_json::json!({ - "fieldName": field_name, - "rowIndex": row_index, - }) - } - SparkError::CannotParseTimestamp { - message, - suggested_func, - } => { - serde_json::json!({ - "message": message, - "suggestedFunc": suggested_func, - }) - } - SparkError::InvalidFractionOfSecond { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::InvalidUtf8String { hex_string } => { - serde_json::json!({ - "hexString": hex_string, - }) - } - SparkError::UnexpectedPositiveValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::UnexpectedNegativeValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::InvalidRegexGroupIndex { - function_name, - group_count, - group_index, - } => { - serde_json::json!({ - "functionName": function_name, - "groupCount": group_count, - "groupIndex": group_index, - }) - } - SparkError::DatatypeCannotOrder { data_type } => { - serde_json::json!({ - "dataType": data_type, - }) - } - SparkError::Arrow(e) => { - serde_json::json!({ - "message": e.to_string(), - }) - } - SparkError::Internal(msg) => { - 
serde_json::json!({ - "message": msg, - }) - } - // Simple errors with no parameters - _ => serde_json::json!({}), - } - } - - /// Returns the appropriate Spark exception class for this error - pub(crate) fn exception_class(&self) -> &'static str { - match self { - // ArithmeticException - SparkError::DivideByZero - | SparkError::RemainderByZero - | SparkError::IntervalDividedByZero - | SparkError::NumericValueOutOfRange { .. } - | SparkError::NumericOutOfRange { .. } // Comet-specific extension - | SparkError::ArithmeticOverflow { .. } - | SparkError::IntegralDivideOverflow - | SparkError::DecimalSumOverflow - | SparkError::BinaryArithmeticOverflow { .. } - | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } - | SparkError::IntervalArithmeticOverflowWithoutSuggestion - | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", - - // CastOverflow gets special handling with CastOverflowException - SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", - - // NumberFormatException (for cast invalid input errors) - SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", - - // ArrayIndexOutOfBoundsException - SparkError::InvalidArrayIndex { .. } - | SparkError::InvalidElementAtIndex { .. } - | SparkError::InvalidBitmapPosition { .. } - | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", - - // RuntimeException - SparkError::CannotParseDecimal - | SparkError::DuplicatedMapKey { .. } - | SparkError::NullMapKey - | SparkError::MapKeyValueDiffSizes - | SparkError::ExceedMapSizeLimit { .. } - | SparkError::CollectionSizeLimitExceeded { .. } - | SparkError::NotNullAssertViolation { .. } - | SparkError::ValueIsNull { .. } // Comet-specific extension - | SparkError::UnexpectedPositiveValue { .. } - | SparkError::UnexpectedNegativeValue { .. } - | SparkError::InvalidRegexGroupIndex { .. 
} - | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", - - // DateTimeException - SparkError::CannotParseTimestamp { .. } - | SparkError::InvalidFractionOfSecond { .. } => "org/apache/spark/SparkDateTimeException", - - // IllegalArgumentException - SparkError::DatatypeCannotOrder { .. } - | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", - - // Generic errors - SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", - } - } - - /// Returns the Spark error class code for this error - fn error_class(&self) -> Option<&'static str> { - match self { - // Cast errors - SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), - SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), - SparkError::NumericValueOutOfRange { .. } => { - Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") - } - SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), - SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), - - // Arithmetic errors - SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), - SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), - SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), - SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), - SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), - - // Array index errors - SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), - SparkError::InvalidElementAtIndex { .. 
} => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), - SparkError::InvalidBitmapPosition { .. } => Some("INVALID_BITMAP_POSITION"), - SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), - - // Map/Collection errors - SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), - SparkError::NullMapKey => Some("NULL_MAP_KEY"), - SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), - SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), - SparkError::CollectionSizeLimitExceeded { .. } => { - Some("COLLECTION_SIZE_LIMIT_EXCEEDED") - } - - // Null validation errors - SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), - SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), - - // DateTime errors - SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), - SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), - - // String/UTF8 errors - SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), - - // Function parameter errors - SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), - SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), - - // Regex errors - SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), - - // Unsupported operation errors - SparkError::DatatypeCannotOrder { .. 
} => Some("DATATYPE_CANNOT_ORDER"), - - // Subquery errors - SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), - - // Generic errors (no error class) - SparkError::Arrow(_) | SparkError::Internal(_) => None, - } - } -} - -pub type SparkResult = Result; - -/// Wrapper that adds QueryContext to SparkError -/// -/// This allows attaching SQL context information (query text, line/position, object name) to errors -#[derive(Debug, Clone)] -pub struct SparkErrorWithContext { - /// The underlying SparkError - pub error: SparkError, - /// Optional QueryContext for SQL location information - pub context: Option>, -} - -impl SparkErrorWithContext { - /// Create a SparkErrorWithContext without context - pub fn new(error: SparkError) -> Self { - Self { - error, - context: None, - } - } - - /// Create a SparkErrorWithContext with QueryContext - pub fn with_context(error: SparkError, context: Arc) -> Self { - Self { - error, - context: Some(context), - } - } - - /// Serialize to JSON including optional context field - /// - /// JSON structure: - /// ```json - /// { - /// "errorType": "DivideByZero", - /// "errorClass": "DIVIDE_BY_ZERO", - /// "params": {}, - /// "context": { - /// "sqlText": "SELECT a/b FROM t", - /// "startIndex": 7, - /// "stopIndex": 9, - /// "line": 1, - /// "startPosition": 7 - /// }, - /// "summary": "== SQL (line 1, position 8) ==\n..." 
- /// } - /// ``` - pub fn to_json(&self) -> String { - let mut json_obj = serde_json::json!({ - "errorType": self.error.error_type_name(), - "errorClass": self.error.error_class().unwrap_or(""), - "params": self.error.params_as_json(), - }); - - if let Some(ctx) = &self.context { - // Serialize context fields - json_obj["context"] = serde_json::json!({ - "sqlText": ctx.sql_text.as_str(), - "startIndex": ctx.start_index, - "stopIndex": ctx.stop_index, - "objectType": ctx.object_type, - "objectName": ctx.object_name, - "line": ctx.line, - "startPosition": ctx.start_position, - }); - - // Add formatted summary - json_obj["summary"] = serde_json::json!(ctx.format_summary()); - } - - serde_json::to_string(&json_obj).unwrap_or_else(|e| { - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - }) - } -} - -impl std::fmt::Display for SparkErrorWithContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error)?; - if let Some(ctx) = &self.context { - write!(f, "\n{}", ctx.format_summary())?; - } - Ok(()) - } -} - -impl std::error::Error for SparkErrorWithContext {} - -impl From for SparkErrorWithContext { - fn from(error: SparkError) -> Self { - SparkErrorWithContext::new(error) - } -} - -impl From for DataFusionError { - fn from(value: SparkErrorWithContext) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -impl From for SparkError { - fn from(value: ArrowError) -> Self { - SparkError::Arrow(Arc::new(value)) - } -} - -impl From for DataFusionError { - fn from(value: SparkError) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_divide_by_zero_json() { - let error = SparkError::DivideByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"DivideByZero\"")); - assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); - - // Verify it's valid JSON - let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DivideByZero"); - assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); - } - - #[test] - fn test_remainder_by_zero_json() { - let error = SparkError::RemainderByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"RemainderByZero\"")); - assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); - } - - #[test] - fn test_binary_overflow_json() { - let error = SparkError::BinaryArithmeticOverflow { - value1: "32767".to_string(), - symbol: "+".to_string(), - value2: "1".to_string(), - function_name: "try_add".to_string(), - }; - let json = error.to_json(); - - // Verify it's valid JSON - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); - assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); - assert_eq!(parsed["params"]["value1"], "32767"); - assert_eq!(parsed["params"]["symbol"], "+"); - assert_eq!(parsed["params"]["value2"], "1"); - assert_eq!(parsed["params"]["functionName"], "try_add"); - } - - #[test] - fn test_invalid_array_index_json() { - let error = SparkError::InvalidArrayIndex { - index_value: 10, - array_size: 3, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "InvalidArrayIndex"); - assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); - assert_eq!(parsed["params"]["indexValue"], 10); - assert_eq!(parsed["params"]["arraySize"], 3); - } - - #[test] - fn test_numeric_value_out_of_range_json() { - let error = SparkError::NumericValueOutOfRange { - value: "999.99".to_string(), - precision: 5, - scale: 2, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); - assert_eq!( - parsed["errorClass"], - "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" - ); - 
assert_eq!(parsed["params"]["value"], "999.99"); - assert_eq!(parsed["params"]["precision"], 5); - assert_eq!(parsed["params"]["scale"], 2); - } - - #[test] - fn test_cast_invalid_value_json() { - let error = SparkError::CastInvalidValue { - value: "abc".to_string(), - from_type: "STRING".to_string(), - to_type: "INT".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "CastInvalidValue"); - assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); - assert_eq!(parsed["params"]["value"], "abc"); - assert_eq!(parsed["params"]["fromType"], "STRING"); - assert_eq!(parsed["params"]["toType"], "INT"); - } - - #[test] - fn test_duplicated_map_key_json() { - let error = SparkError::DuplicatedMapKey { - key: "duplicate_key".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DuplicatedMapKey"); - assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); - assert_eq!(parsed["params"]["key"], "duplicate_key"); - } - - #[test] - fn test_null_map_key_json() { - let error = SparkError::NullMapKey; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NullMapKey"); - assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); - // Params should be an empty object - assert_eq!(parsed["params"], serde_json::json!({})); - } - - #[test] - fn test_error_class_mapping() { - // Test that error_class() returns the correct error class - assert_eq!( - SparkError::DivideByZero.error_class(), - Some("DIVIDE_BY_ZERO") - ); - assert_eq!( - SparkError::RemainderByZero.error_class(), - Some("REMAINDER_BY_ZERO") - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .error_class(), - Some("INVALID_ARRAY_INDEX") - ); - assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); - } - - #[test] - fn test_exception_class_mapping() { - // Test that exception_class() returns the correct Java exception class - assert_eq!( - SparkError::DivideByZero.exception_class(), - "org/apache/spark/SparkArithmeticException" - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .exception_class(), - "org/apache/spark/SparkArrayIndexOutOfBoundsException" - ); - assert_eq!( - SparkError::NullMapKey.exception_class(), - "org/apache/spark/SparkRuntimeException" - ); - } -} diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 179bacc1b9..9f08e480f2 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -30,11 +30,12 @@ edition = { workspace = true } arrow = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true } -datafusion-comet-jvm-bridge = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } +serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +thiserror = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index b162479def..bd7efe1638 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -15,10 +15,557 @@ // specific language governing permissions and limitations // under the License. -// Re-export SparkError types from jvm-bridge crate -pub use datafusion_comet_jvm_bridge::spark_error::{ - SparkError, SparkErrorWithContext, SparkResult, -}; +use arrow::error::ArrowError; +use datafusion::common::DataFusionError; +use std::sync::Arc; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SparkError { + // This list was generated from the Spark code. 
Many of the exceptions are not yet used by Comet + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] + NumericOutOfRange { value: String }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] + CannotParseDecimal, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntegralDivideOverflow, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DecimalSumOverflow, + + #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DivideByZero, + + #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + RemainderByZero, + + #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] + IntervalDividedByZero, + + #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + BinaryArithmeticOverflow { + value1: String, + symbol: String, + value2: String, + function_name: String, + }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithSuggestion { function_name: String }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithoutSuggestion, + + #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] + DatetimeOverflow, + + #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidArrayIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidElementAtIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] + InvalidBitmapPosition { + bit_position: i64, + bitmap_num_bytes: i64, + bitmap_num_bits: i64, + }, + + #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] + InvalidIndexOfZero, + + #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] + DuplicatedMapKey { key: String }, + + #[error("[NULL_MAP_KEY] Cannot use null as map key.")] + NullMapKey, + + #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] + MapKeyValueDiffSizes, + + #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] + ExceedMapSizeLimit { size: i32, max_size: i32 }, + + #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] + CollectionSizeLimitExceeded { + num_elements: i64, + max_elements: i64, + }, + + #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] + NotNullAssertViolation { field_name: String }, + + #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] + ValueIsNull { field_name: String, row_index: i32 }, + + #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. 
Try using `{suggested_func}` instead.")] + CannotParseTimestamp { + message: String, + suggested_func: String, + }, + + #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidFractionOfSecond { value: f64 }, + + #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] + InvalidUtf8String { hex_string: String }, + + #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] + UnexpectedPositiveValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] + UnexpectedNegativeValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. 
Group count is {group_count}.")] + InvalidRegexGroupIndex { + function_name: String, + group_count: i32, + group_index: i32, + }, + + #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] + DatatypeCannotOrder { data_type: String }, + + #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] + ScalarSubqueryTooManyRows, + + #[error("ArrowError: {0}.")] + Arrow(Arc), + + #[error("InternalError: {0}.")] + Internal(String), +} + +impl SparkError { + /// Serialize this error to JSON format for JNI transfer + pub fn to_json(&self) -> String { + let error_class = self.error_class().unwrap_or(""); + + // Create a JSON structure with errorType, errorClass, and params + match serde_json::to_string(&serde_json::json!({ + "errorType": self.error_type_name(), + "errorClass": error_class, + "params": self.params_as_json(), + })) { + Ok(json) => json, + Err(e) => { + // Fallback if serialization fails + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + } + } + } + + /// Get the error type name for JSON serialization + fn error_type_name(&self) -> &'static str { + match self { + SparkError::CastInvalidValue { .. } => "CastInvalidValue", + SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", + SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", + SparkError::CastOverFlow { .. } => "CastOverFlow", + SparkError::CannotParseDecimal => "CannotParseDecimal", + SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", + SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", + SparkError::DecimalSumOverflow => "DecimalSumOverflow", + SparkError::DivideByZero => "DivideByZero", + SparkError::RemainderByZero => "RemainderByZero", + SparkError::IntervalDividedByZero => "IntervalDividedByZero", + SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", + SparkError::IntervalArithmeticOverflowWithSuggestion { .. 
} => { + "IntervalArithmeticOverflowWithSuggestion" + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + "IntervalArithmeticOverflowWithoutSuggestion" + } + SparkError::DatetimeOverflow => "DatetimeOverflow", + SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", + SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", + SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", + SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", + SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", + SparkError::NullMapKey => "NullMapKey", + SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", + SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", + SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", + SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", + SparkError::ValueIsNull { .. } => "ValueIsNull", + SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", + SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", + SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", + SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", + SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", + SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", + SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", + SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", + SparkError::Arrow(_) => "Arrow", + SparkError::Internal(_) => "Internal", + } + } + + /// Extract parameters as JSON value + fn params_as_json(&self) -> serde_json::Value { + match self { + SparkError::CastInvalidValue { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::NumericValueOutOfRange { + value, + precision, + scale, + } => { + serde_json::json!({ + "value": value, + "precision": precision, + "scale": scale, + }) + } + SparkError::NumericOutOfRange { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::CastOverFlow { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::ArithmeticOverflow { from_type } => { + serde_json::json!({ + "fromType": from_type, + }) + } + SparkError::BinaryArithmeticOverflow { + value1, + symbol, + value2, + function_name, + } => { + serde_json::json!({ + "value1": value1, + "symbol": symbol, + "value2": value2, + "functionName": function_name, + }) + } + SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { + serde_json::json!({ + "functionName": function_name, + }) + } + SparkError::InvalidArrayIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidElementAtIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidBitmapPosition { + bit_position, + bitmap_num_bytes, + bitmap_num_bits, + } => { + serde_json::json!({ + "bitPosition": bit_position, + "bitmapNumBytes": bitmap_num_bytes, + "bitmapNumBits": bitmap_num_bits, + }) + } + SparkError::DuplicatedMapKey { key } => { + serde_json::json!({ + "key": 
key, + }) + } + SparkError::ExceedMapSizeLimit { size, max_size } => { + serde_json::json!({ + "size": size, + "maxSize": max_size, + }) + } + SparkError::CollectionSizeLimitExceeded { + num_elements, + max_elements, + } => { + serde_json::json!({ + "numElements": num_elements, + "maxElements": max_elements, + }) + } + SparkError::NotNullAssertViolation { field_name } => { + serde_json::json!({ + "fieldName": field_name, + }) + } + SparkError::ValueIsNull { + field_name, + row_index, + } => { + serde_json::json!({ + "fieldName": field_name, + "rowIndex": row_index, + }) + } + SparkError::CannotParseTimestamp { + message, + suggested_func, + } => { + serde_json::json!({ + "message": message, + "suggestedFunc": suggested_func, + }) + } + SparkError::InvalidFractionOfSecond { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::InvalidUtf8String { hex_string } => { + serde_json::json!({ + "hexString": hex_string, + }) + } + SparkError::UnexpectedPositiveValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::UnexpectedNegativeValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::InvalidRegexGroupIndex { + function_name, + group_count, + group_index, + } => { + serde_json::json!({ + "functionName": function_name, + "groupCount": group_count, + "groupIndex": group_index, + }) + } + SparkError::DatatypeCannotOrder { data_type } => { + serde_json::json!({ + "dataType": data_type, + }) + } + SparkError::Arrow(e) => { + serde_json::json!({ + "message": e.to_string(), + }) + } + SparkError::Internal(msg) => { + serde_json::json!({ + "message": msg, + }) + } + // Simple errors with no parameters + _ => serde_json::json!({}), + } + } + + /// Returns the appropriate Spark exception class for this error + pub fn exception_class(&self) -> &'static 
str { + match self { + // ArithmeticException + SparkError::DivideByZero + | SparkError::RemainderByZero + | SparkError::IntervalDividedByZero + | SparkError::NumericValueOutOfRange { .. } + | SparkError::NumericOutOfRange { .. } // Comet-specific extension + | SparkError::ArithmeticOverflow { .. } + | SparkError::IntegralDivideOverflow + | SparkError::DecimalSumOverflow + | SparkError::BinaryArithmeticOverflow { .. } + | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } + | SparkError::IntervalArithmeticOverflowWithoutSuggestion + | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", + + // CastOverflow gets special handling with CastOverflowException + SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", + + // NumberFormatException (for cast invalid input errors) + SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", + + // ArrayIndexOutOfBoundsException + SparkError::InvalidArrayIndex { .. } + | SparkError::InvalidElementAtIndex { .. } + | SparkError::InvalidBitmapPosition { .. } + | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", + + // RuntimeException + SparkError::CannotParseDecimal + | SparkError::DuplicatedMapKey { .. } + | SparkError::NullMapKey + | SparkError::MapKeyValueDiffSizes + | SparkError::ExceedMapSizeLimit { .. } + | SparkError::CollectionSizeLimitExceeded { .. } + | SparkError::NotNullAssertViolation { .. } + | SparkError::ValueIsNull { .. } // Comet-specific extension + | SparkError::UnexpectedPositiveValue { .. } + | SparkError::UnexpectedNegativeValue { .. } + | SparkError::InvalidRegexGroupIndex { .. } + | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", + + // DateTimeException + SparkError::CannotParseTimestamp { .. } + | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", + + // IllegalArgumentException + SparkError::DatatypeCannotOrder { .. } + | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", + + // Generic errors + SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", + } + } + + /// Returns the Spark error class code for this error + fn error_class(&self) -> Option<&'static str> { + match self { + // Cast errors + SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), + SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), + SparkError::NumericValueOutOfRange { .. } => { + Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") + } + SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), + SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), + + // Arithmetic errors + SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), + SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), + SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), + SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), + SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), + + // Array index errors + SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), + SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), + SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), + + // Map/Collection errors + SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), + SparkError::NullMapKey => Some("NULL_MAP_KEY"), + SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), + SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), + SparkError::CollectionSizeLimitExceeded { .. } => { + Some("COLLECTION_SIZE_LIMIT_EXCEEDED") + } + + // Null validation errors + SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), + SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), + + // DateTime errors + SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), + SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), + + // String/UTF8 errors + SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), + + // Function parameter errors + SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), + SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), + + // Regex errors + SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), + + // Unsupported operation errors + SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), + + // Subquery errors + SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), + + // Generic errors (no error class) + SparkError::Arrow(_) | SparkError::Internal(_) => None, + } + } +} + +pub type SparkResult = Result; /// Convert decimal overflow to SparkError::NumericValueOutOfRange. 
pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { @@ -28,3 +575,285 @@ pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkErr scale, } } + +/// Wrapper that adds QueryContext to SparkError +/// +/// This allows attaching SQL context information (query text, line/position, object name) to errors +#[derive(Debug, Clone)] +pub struct SparkErrorWithContext { + /// The underlying SparkError + pub error: SparkError, + /// Optional QueryContext for SQL location information + pub context: Option>, +} + +impl SparkErrorWithContext { + /// Create a SparkErrorWithContext without context + pub fn new(error: SparkError) -> Self { + Self { + error, + context: None, + } + } + + /// Create a SparkErrorWithContext with QueryContext + pub fn with_context(error: SparkError, context: Arc) -> Self { + Self { + error, + context: Some(context), + } + } + + /// Serialize to JSON including optional context field + /// + /// JSON structure: + /// ```json + /// { + /// "errorType": "DivideByZero", + /// "errorClass": "DIVIDE_BY_ZERO", + /// "params": {}, + /// "context": { + /// "sqlText": "SELECT a/b FROM t", + /// "startIndex": 7, + /// "stopIndex": 9, + /// "line": 1, + /// "startPosition": 7 + /// }, + /// "summary": "== SQL (line 1, position 8) ==\n..." 
+ /// } + /// ``` + pub fn to_json(&self) -> String { + let mut json_obj = serde_json::json!({ + "errorType": self.error.error_type_name(), + "errorClass": self.error.error_class().unwrap_or(""), + "params": self.error.params_as_json(), + }); + + if let Some(ctx) = &self.context { + // Serialize context fields + json_obj["context"] = serde_json::json!({ + "sqlText": ctx.sql_text.as_str(), + "startIndex": ctx.start_index, + "stopIndex": ctx.stop_index, + "objectType": ctx.object_type, + "objectName": ctx.object_name, + "line": ctx.line, + "startPosition": ctx.start_position, + }); + + // Add formatted summary + json_obj["summary"] = serde_json::json!(ctx.format_summary()); + } + + serde_json::to_string(&json_obj).unwrap_or_else(|e| { + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + }) + } +} + +impl std::fmt::Display for SparkErrorWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error)?; + if let Some(ctx) = &self.context { + write!(f, "\n{}", ctx.format_summary())?; + } + Ok(()) + } +} + +impl std::error::Error for SparkErrorWithContext {} + +impl From for SparkErrorWithContext { + fn from(error: SparkError) -> Self { + SparkErrorWithContext::new(error) + } +} + +impl From for DataFusionError { + fn from(value: SparkErrorWithContext) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +impl From for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(Arc::new(value)) + } +} + +impl From for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_divide_by_zero_json() { + let error = SparkError::DivideByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"DivideByZero\"")); + assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); + + // Verify it's valid JSON + let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DivideByZero"); + assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); + } + + #[test] + fn test_remainder_by_zero_json() { + let error = SparkError::RemainderByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"RemainderByZero\"")); + assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); + } + + #[test] + fn test_binary_overflow_json() { + let error = SparkError::BinaryArithmeticOverflow { + value1: "32767".to_string(), + symbol: "+".to_string(), + value2: "1".to_string(), + function_name: "try_add".to_string(), + }; + let json = error.to_json(); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); + assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); + assert_eq!(parsed["params"]["value1"], "32767"); + assert_eq!(parsed["params"]["symbol"], "+"); + assert_eq!(parsed["params"]["value2"], "1"); + assert_eq!(parsed["params"]["functionName"], "try_add"); + } + + #[test] + fn test_invalid_array_index_json() { + let error = SparkError::InvalidArrayIndex { + index_value: 10, + array_size: 3, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "InvalidArrayIndex"); + assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); + assert_eq!(parsed["params"]["indexValue"], 10); + assert_eq!(parsed["params"]["arraySize"], 3); + } + + #[test] + fn test_numeric_value_out_of_range_json() { + let error = SparkError::NumericValueOutOfRange { + value: "999.99".to_string(), + precision: 5, + scale: 2, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); + assert_eq!( + parsed["errorClass"], + "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" + ); + 
assert_eq!(parsed["params"]["value"], "999.99"); + assert_eq!(parsed["params"]["precision"], 5); + assert_eq!(parsed["params"]["scale"], 2); + } + + #[test] + fn test_cast_invalid_value_json() { + let error = SparkError::CastInvalidValue { + value: "abc".to_string(), + from_type: "STRING".to_string(), + to_type: "INT".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "CastInvalidValue"); + assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); + assert_eq!(parsed["params"]["value"], "abc"); + assert_eq!(parsed["params"]["fromType"], "STRING"); + assert_eq!(parsed["params"]["toType"], "INT"); + } + + #[test] + fn test_duplicated_map_key_json() { + let error = SparkError::DuplicatedMapKey { + key: "duplicate_key".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DuplicatedMapKey"); + assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); + assert_eq!(parsed["params"]["key"], "duplicate_key"); + } + + #[test] + fn test_null_map_key_json() { + let error = SparkError::NullMapKey; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NullMapKey"); + assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); + // Params should be an empty object + assert_eq!(parsed["params"], serde_json::json!({})); + } + + #[test] + fn test_error_class_mapping() { + // Test that error_class() returns the correct error class + assert_eq!( + SparkError::DivideByZero.error_class(), + Some("DIVIDE_BY_ZERO") + ); + assert_eq!( + SparkError::RemainderByZero.error_class(), + Some("REMAINDER_BY_ZERO") + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .error_class(), + Some("INVALID_ARRAY_INDEX") + ); + assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); + } + + #[test] + fn test_exception_class_mapping() { + // Test that exception_class() returns the correct Java exception class + assert_eq!( + SparkError::DivideByZero.exception_class(), + "org/apache/spark/SparkArithmeticException" + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .exception_class(), + "org/apache/spark/SparkArrayIndexOutOfBoundsException" + ); + assert_eq!( + SparkError::NullMapKey.exception_class(), + "org/apache/spark/SparkRuntimeException" + ); + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 072fa1fad7..d576c803a1 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -74,7 +74,7 @@ pub use datetime_funcs::{ SparkDateDiff, SparkDateTrunc, SparkHour, SparkMakeDate, SparkMinute, SparkSecond, SparkUnixTimestamp, TimestampTruncExpr, }; -pub use error::{SparkError, SparkErrorWithContext, SparkResult}; +pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs index e7ae1e01a9..10ecef6550 100644 --- a/native/spark-expr/src/query_context.rs +++ b/native/spark-expr/src/query_context.rs @@ -15,7 +15,389 @@ // specific language governing permissions and limitations // under the License. -// Re-export QueryContext types from jvm-bridge crate -pub use datafusion_comet_jvm_bridge::query_context::{ - create_query_context_map, QueryContext, QueryContextMap, -}; +//! Query execution context for error reporting +//! +//! This module provides QueryContext which mirrors Spark's SQLQueryContext +//! for providing SQL text, line/position information, and error location +//! pointers in exception messages. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Based on Spark's SQLQueryContext for error reporting. 
+/// +/// Contains information about where an error occurred in a SQL query, +/// including the full SQL text, line/column positions, and object context. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct QueryContext { + /// Full SQL query text + #[serde(rename = "sqlText")] + pub sql_text: Arc, + + /// Start offset in SQL text (0-based, character index) + #[serde(rename = "startIndex")] + pub start_index: i32, + + /// Stop offset in SQL text (0-based, character index, inclusive) + #[serde(rename = "stopIndex")] + pub stop_index: i32, + + /// Object type (e.g., "VIEW", "Project", "Filter") + #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] + pub object_type: Option, + + /// Object name (e.g., view name, column name) + #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] + pub object_name: Option, + + /// Line number in SQL query (1-based) + pub line: i32, + + /// Column position within the line (0-based) + #[serde(rename = "startPosition")] + pub start_position: i32, +} + +impl QueryContext { + #[allow(clippy::too_many_arguments)] + pub fn new( + sql_text: String, + start_index: i32, + stop_index: i32, + object_type: Option, + object_name: Option, + line: i32, + start_position: i32, + ) -> Self { + Self { + sql_text: Arc::new(sql_text), + start_index, + stop_index, + object_type, + object_name, + line, + start_position, + } + } + + /// Convert a character index to a byte offset in the SQL text. + /// Returns None if the character index is out of range. + fn char_index_to_byte_offset(&self, char_index: usize) -> Option { + self.sql_text + .char_indices() + .nth(char_index) + .map(|(byte_offset, _)| byte_offset) + } + + /// Generate a summary string showing SQL fragment with error location. 
+ /// (From SQLQueryContext.summary) + /// + /// Format example: + /// ```text + /// == SQL of VIEW v1 (line 1, position 8) == + /// SELECT a/b FROM t + /// ^^^ + /// ``` + pub fn format_summary(&self) -> String { + let start_char = self.start_index.max(0) as usize; + // stop_index is inclusive; fragment covers [start, stop] + let stop_char = (self.stop_index + 1).max(0) as usize; + + let fragment = match ( + self.char_index_to_byte_offset(start_char), + // stop_char may equal sql_text.chars().count() (one past the end) + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], + _ => "", + }; + + // Build the header line + let mut summary = String::from("== SQL"); + + if let Some(obj_type) = &self.object_type { + if !obj_type.is_empty() { + summary.push_str(" of "); + summary.push_str(obj_type); + + if let Some(obj_name) = &self.object_name { + if !obj_name.is_empty() { + summary.push(' '); + summary.push_str(obj_name); + } + } + } + } + + summary.push_str(&format!( + " (line {}, position {}) ==\n", + self.line, + self.start_position + 1 // Convert 0-based to 1-based for display + )); + + // Add the SQL text with fragment highlighted + summary.push_str(&self.sql_text); + summary.push('\n'); + + // Add caret pointer + let caret_position = self.start_position.max(0) as usize; + summary.push_str(&" ".repeat(caret_position)); + // fragment.chars().count() gives the correct display width for non-ASCII + summary.push_str(&"^".repeat(fragment.chars().count().max(1))); + + summary + } + + /// Returns the SQL fragment that caused the error. 
+ #[cfg(test)] + fn fragment(&self) -> String { + let start_char = self.start_index.max(0) as usize; + let stop_char = (self.stop_index + 1).max(0) as usize; + + match ( + self.char_index_to_byte_offset(start_char), + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), + _ => String::new(), + } + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +/// Map that stores QueryContext information for expressions during execution. +/// +/// This map is populated during plan deserialization and accessed +/// during error creation to attach SQL context to exceptions. +#[derive(Debug)] +pub struct QueryContextMap { + /// Map from expression ID to QueryContext + contexts: RwLock>>, +} + +impl QueryContextMap { + pub fn new() -> Self { + Self { + contexts: RwLock::new(HashMap::new()), + } + } + + /// Register a QueryContext for an expression ID. + /// + /// If the expression ID already exists, it will be replaced. + /// + /// # Arguments + /// * `expr_id` - Unique expression identifier from protobuf + /// * `context` - QueryContext containing SQL text and position info + pub fn register(&self, expr_id: u64, context: QueryContext) { + let mut contexts = self.contexts.write().unwrap(); + contexts.insert(expr_id, Arc::new(context)); + } + + /// Get the QueryContext for an expression ID. + /// + /// Returns None if no context is registered for this expression. + /// + /// # Arguments + /// * `expr_id` - Expression identifier to look up + pub fn get(&self, expr_id: u64) -> Option> { + let contexts = self.contexts.read().unwrap(); + contexts.get(&expr_id).cloned() + } + + /// Clear all registered contexts. + /// + /// This is typically called after plan execution completes to free memory. 
+ pub fn clear(&self) { + let mut contexts = self.contexts.write().unwrap(); + contexts.clear(); + } + + /// Return the number of registered contexts (for debugging/testing) + pub fn len(&self) -> usize { + let contexts = self.contexts.read().unwrap(); + contexts.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for QueryContextMap { + fn default() -> Self { + Self::new() + } +} + +/// Create a new session-scoped QueryContextMap. +/// +/// This should be called once per SessionContext during plan creation +/// and passed to expressions that need query context for error reporting. +pub fn create_query_context_map() -> Arc { + Arc::new(QueryContextMap::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_context_creation() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); + assert_eq!(ctx.start_index, 7); + assert_eq!(ctx.stop_index, 9); + assert_eq!(ctx.object_type, Some("Divide".to_string())); + assert_eq!(ctx.object_name, Some("a/b".to_string())); + assert_eq!(ctx.line, 1); + assert_eq!(ctx.start_position, 7); + } + + #[test] + fn test_query_context_serialization() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + let json = serde_json::to_string(&ctx).unwrap(); + let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); + + assert_eq!(ctx, deserialized); + } + + #[test] + fn test_format_summary() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("VIEW".to_string()), + Some("v1".to_string()), + 1, + 7, + ); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM 
t")); + assert!(summary.contains("^^^")); // Three carets for "a/b" + } + + #[test] + fn test_format_summary_without_object() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + } + + #[test] + fn test_fragment() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + assert_eq!(ctx.fragment(), "a/b"); + } + + #[test] + fn test_arc_string_sharing() { + let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let ctx2 = ctx1.clone(); + + // Arc should share the same allocation + assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); + } + + #[test] + fn test_json_with_optional_fields() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let json = serde_json::to_string(&ctx).unwrap(); + + // Should not serialize objectType and objectName when None + assert!(!json.contains("objectType")); + assert!(!json.contains("objectName")); + } + + #[test] + fn test_map_register_and_get() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx.clone()); + + let retrieved = map.get(1).unwrap(); + assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); + assert_eq!(retrieved.start_index, 7); + } + + #[test] + fn test_map_get_nonexistent() { + let map = QueryContextMap::new(); + assert!(map.get(999).is_none()); + } + + #[test] + fn test_map_clear() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx); + assert_eq!(map.len(), 1); + + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + } + + // Verify that fragment() and format_summary() correctly handle SQL text that + // 
contains multi-byte characters + + #[test] + fn test_fragment_non_ascii_accented() { + // "é" is a 2-byte UTF-8 sequence (U+00E9). + // SQL: "SELECT café FROM t" + // 0123456789... + // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. + // start_index=7, stop_index=10 should yield "café" + let sql = "SELECT café FROM t".to_string(); + let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); + assert_eq!(ctx.fragment(), "café"); + } +} From 8c0e1157f197d991da8b9275980052f7e40e8ec2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 16:45:14 -0600 Subject: [PATCH 06/11] trigger CI From 8e81368b2496d95f81d4ea94bdd7a16f71d29c4c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 18:06:08 -0600 Subject: [PATCH 07/11] refactor: extract spark-errors crate to replace callback pattern Move SparkError, SparkErrorWithContext, QueryContext, and QueryContextMap into a new spark-errors crate that both jvm-bridge and spark-expr depend on. This allows jvm-bridge to directly downcast SparkError variants in throw_exception, eliminating the ExternalErrorHandler callback, OnceCell, and init-time registration that were previously needed. 
--- native/Cargo.lock | 13 + native/Cargo.toml | 5 +- native/core/src/lib.rs | 30 - native/jvm-bridge/Cargo.toml | 1 + native/jvm-bridge/src/errors.rs | 53 +- native/spark-errors/Cargo.toml | 40 ++ native/spark-errors/src/error.rs | 842 ++++++++++++++++++++++ native/spark-errors/src/lib.rs | 22 + native/spark-errors/src/query_context.rs | 403 +++++++++++ native/spark-expr/Cargo.toml | 1 + native/spark-expr/src/error.rs | 846 +---------------------- native/spark-expr/src/query_context.rs | 388 +---------- 12 files changed, 1354 insertions(+), 1290 deletions(-) create mode 100644 native/spark-errors/Cargo.toml create mode 100644 native/spark-errors/src/error.rs create mode 100644 native/spark-errors/src/lib.rs create mode 100644 native/spark-errors/src/query_context.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 065d67f9c9..f4531b36e8 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1905,6 +1905,7 @@ dependencies = [ "arrow", "assertables", "datafusion", + "datafusion-comet-spark-errors", "jni", "lazy_static", "once_cell", @@ -1937,6 +1938,17 @@ dependencies = [ "prost-build", ] +[[package]] +name = "datafusion-comet-spark-errors" +version = "0.14.0" +dependencies = [ + "arrow", + "datafusion", + "serde", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "datafusion-comet-spark-expr" version = "0.14.0" @@ -1947,6 +1959,7 @@ dependencies = [ "chrono-tz", "criterion", "datafusion", + "datafusion-comet-spark-errors", "futures", "hex", "num", diff --git a/native/Cargo.toml b/native/Cargo.toml index 847045a9eb..960fb56509 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[workspace] -default-members = ["core", "spark-expr", "proto", "jvm-bridge"] -members = ["core", "spark-expr", "proto", "jvm-bridge", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "spark-errors", "proto", "jvm-bridge"] +members = ["core", "spark-expr", "spark-errors", "proto", "jvm-bridge", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -43,6 +43,7 @@ datafusion-datasource = { version = "52.2.0" } datafusion-physical-expr-adapter = { version = "52.2.0" } datafusion-spark = { version = "52.2.0" } datafusion-comet-spark-expr = { path = "spark-expr" } +datafusion-comet-spark-errors = { path = "spark-errors" } datafusion-comet-jvm-bridge = { path = "jvm-bridge" } datafusion-comet-proto = { path = "proto" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index c7473856ba..8ee3139ff9 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -91,9 +91,6 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( // Initialize the error handling to capture panic backtraces errors::init(); - // Register SparkError handler for JNI exception throwing - errors::register_external_error_handler(handle_spark_error); - try_unwrap_or_throw(&e, |mut env| { let path: String = env.get_string(&log_conf_path)?.into(); @@ -122,33 +119,6 @@ pub extern "system" fn Java_org_apache_comet_NativeBase_init( }) } -/// Handle SparkError variants in DataFusionError::External for JNI exception throwing. -/// Returns true if the error was handled. 
-fn handle_spark_error( - error: &(dyn std::error::Error + Send + Sync + 'static), - env: &mut JNIEnv, -) -> bool { - use datafusion_comet_spark_expr::{SparkError, SparkErrorWithContext}; - - if let Some(spark_error_with_ctx) = error.downcast_ref::() { - let json_message = spark_error_with_ctx.to_json(); - let _ = env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, - ); - return true; - } - if let Some(spark_error) = error.downcast_ref::() { - let json_message = spark_error.to_json(); - let _ = env.throw_new( - "org/apache/comet/exceptions/CometQueryExecutionException", - json_message, - ); - return true; - } - false -} - const LOG_PATTERN: &str = "{d(%y/%m/%d %H:%M:%S)} {l} {f}: {m}{n}"; /// JNI method to check if a specific feature is enabled in the native Rust code. diff --git a/native/jvm-bridge/Cargo.toml b/native/jvm-bridge/Cargo.toml index 07ddc3f883..a869dd1eee 100644 --- a/native/jvm-bridge/Cargo.toml +++ b/native/jvm-bridge/Cargo.toml @@ -39,6 +39,7 @@ lazy_static = "1.4.0" once_cell = "1.18.0" paste = "1.0.14" prost = "0.14.3" +datafusion-comet-spark-errors = { workspace = true } [dev-dependencies] jni = { version = "0.21", features = ["invocation"] } diff --git a/native/jvm-bridge/src/errors.rs b/native/jvm-bridge/src/errors.rs index a71ee9105a..62ce3b05fb 100644 --- a/native/jvm-bridge/src/errors.rs +++ b/native/jvm-bridge/src/errors.rs @@ -19,6 +19,7 @@ use arrow::error::ArrowError; use datafusion::common::DataFusionError; +use datafusion_comet_spark_errors::{SparkError, SparkErrorWithContext}; use jni::errors::{Exception, ToException}; use regex::Regex; @@ -40,25 +41,9 @@ use jni::sys::{jboolean, jbyte, jchar, jdouble, jfloat, jint, jlong, jobject, js use jni::objects::{GlobalRef, JThrowable}; use jni::JNIEnv; use lazy_static::lazy_static; -use once_cell::sync::OnceCell; use parquet::errors::ParquetError; use thiserror::Error; -/// Handler for DataFusionError::External errors during JNI exception throwing. 
-/// Returns true if the error was handled (exception thrown), false to fall through -/// to the default handler. -pub type ExternalErrorHandler = - fn(error: &(dyn std::error::Error + Send + Sync + 'static), env: &mut JNIEnv) -> bool; - -static EXTERNAL_ERROR_HANDLER: OnceCell = OnceCell::new(); - -/// Register a handler for DataFusionError::External errors. -/// This allows the core crate to register SparkError-specific handling -/// without the jvm-bridge crate needing to know about SparkError. -pub fn register_external_error_handler(handler: ExternalErrorHandler) { - let _ = EXTERNAL_ERROR_HANDLER.set(handler); -} - lazy_static! { static ref PANIC_BACKTRACE: Arc>> = Arc::new(Mutex::new(None)); } @@ -489,26 +474,34 @@ fn throw_exception(env: &mut JNIEnv, error: &CometError, backtrace: Option env.throw(<&JThrowable>::from(throwable.as_obj())), - // Handle DataFusion errors containing external errors (e.g. SparkError) + // Handle DataFusion errors containing SparkError or SparkErrorWithContext CometError::DataFusion { msg: _, source: DataFusionError::External(e), } => { - // Try registered handler first (e.g. 
SparkError handling from core) - if let Some(handler) = EXTERNAL_ERROR_HANDLER.get() { - if handler(e.as_ref(), env) { - return; + if let Some(spark_error_with_ctx) = e.downcast_ref::() { + let json_message = spark_error_with_ctx.to_json(); + env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ) + } else if let Some(spark_error) = e.downcast_ref::() { + let json_message = spark_error.to_json(); + env.throw_new( + "org/apache/comet/exceptions/CometQueryExecutionException", + json_message, + ) + } else { + // Fall through to generic exception + let exception = error.to_exception(); + match backtrace { + Some(backtrace_string) => env.throw_new( + exception.class, + to_stacktrace_string(exception.msg, backtrace_string).unwrap(), + ), + _ => env.throw_new(exception.class, exception.msg), } } - // Fall through to generic exception - let exception = error.to_exception(); - match backtrace { - Some(backtrace_string) => env.throw_new( - exception.class, - to_stacktrace_string(exception.msg, backtrace_string).unwrap(), - ), - _ => env.throw_new(exception.class, exception.msg), - } } _ => { let exception = error.to_exception(); diff --git a/native/spark-errors/Cargo.toml b/native/spark-errors/Cargo.toml new file mode 100644 index 0000000000..1b0e40053f --- /dev/null +++ b/native/spark-errors/Cargo.toml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-comet-spark-errors" +description = "Apache DataFusion Comet: Spark error types" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +publish = false + +[dependencies] +arrow = { workspace = true } +datafusion = { workspace = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = { workspace = true } + +[lib] +name = "datafusion_comet_spark_errors" +path = "src/lib.rs" diff --git a/native/spark-errors/src/error.rs b/native/spark-errors/src/error.rs new file mode 100644 index 0000000000..e36f069ac2 --- /dev/null +++ b/native/spark-errors/src/error.rs @@ -0,0 +1,842 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::error::ArrowError; +use datafusion::common::DataFusionError; +use std::sync::Arc; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum SparkError { + // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] + NumericOutOfRange { value: String }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] + CannotParseDecimal, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. 
Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntegralDivideOverflow, + + #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DecimalSumOverflow, + + #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + DivideByZero, + + #[error("[REMAINDER_BY_ZERO] Division by zero. Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + RemainderByZero, + + #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] + IntervalDividedByZero, + + #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + BinaryArithmeticOverflow { + value1: String, + symbol: String, + value2: String, + function_name: String, + }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithSuggestion { function_name: String }, + + #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + IntervalArithmeticOverflowWithoutSuggestion, + + #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] + DatetimeOverflow, + + #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. 
Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidArrayIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidElementAtIndex { index_value: i32, array_size: i32 }, + + #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] + InvalidBitmapPosition { + bit_position: i64, + bitmap_num_bytes: i64, + bitmap_num_bits: i64, + }, + + #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] + InvalidIndexOfZero, + + #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] + DuplicatedMapKey { key: String }, + + #[error("[NULL_MAP_KEY] Cannot use null as map key.")] + NullMapKey, + + #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] + MapKeyValueDiffSizes, + + #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] + ExceedMapSizeLimit { size: i32, max_size: i32 }, + + #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] + CollectionSizeLimitExceeded { + num_elements: i64, + max_elements: i64, + }, + + #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] + NotNullAssertViolation { field_name: String }, + + #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] + ValueIsNull { field_name: String, 
row_index: i32 }, + + #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] + CannotParseTimestamp { + message: String, + suggested_func: String, + }, + + #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + InvalidFractionOfSecond { value: f64 }, + + #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] + InvalidUtf8String { hex_string: String }, + + #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] + UnexpectedPositiveValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] + UnexpectedNegativeValue { + parameter_name: String, + actual_value: i32, + }, + + #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. 
Group count is {group_count}.")] + InvalidRegexGroupIndex { + function_name: String, + group_count: i32, + group_index: i32, + }, + + #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] + DatatypeCannotOrder { data_type: String }, + + #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] + ScalarSubqueryTooManyRows, + + #[error("ArrowError: {0}.")] + Arrow(Arc), + + #[error("InternalError: {0}.")] + Internal(String), +} + +impl SparkError { + /// Serialize this error to JSON format for JNI transfer + pub fn to_json(&self) -> String { + let error_class = self.error_class().unwrap_or(""); + + // Create a JSON structure with errorType, errorClass, and params + match serde_json::to_string(&serde_json::json!({ + "errorType": self.error_type_name(), + "errorClass": error_class, + "params": self.params_as_json(), + })) { + Ok(json) => json, + Err(e) => { + // Fallback if serialization fails + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + } + } + } + + /// Get the error type name for JSON serialization + pub(crate) fn error_type_name(&self) -> &'static str { + match self { + SparkError::CastInvalidValue { .. } => "CastInvalidValue", + SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", + SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", + SparkError::CastOverFlow { .. } => "CastOverFlow", + SparkError::CannotParseDecimal => "CannotParseDecimal", + SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", + SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", + SparkError::DecimalSumOverflow => "DecimalSumOverflow", + SparkError::DivideByZero => "DivideByZero", + SparkError::RemainderByZero => "RemainderByZero", + SparkError::IntervalDividedByZero => "IntervalDividedByZero", + SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", + SparkError::IntervalArithmeticOverflowWithSuggestion { .. 
} => { + "IntervalArithmeticOverflowWithSuggestion" + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + "IntervalArithmeticOverflowWithoutSuggestion" + } + SparkError::DatetimeOverflow => "DatetimeOverflow", + SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", + SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", + SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", + SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", + SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", + SparkError::NullMapKey => "NullMapKey", + SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", + SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", + SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", + SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", + SparkError::ValueIsNull { .. } => "ValueIsNull", + SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", + SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", + SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", + SparkError::UnexpectedPositiveValue { .. } => "UnexpectedPositiveValue", + SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", + SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", + SparkError::DatatypeCannotOrder { .. 
} => "DatatypeCannotOrder", + SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", + SparkError::Arrow(_) => "Arrow", + SparkError::Internal(_) => "Internal", + } + } + + /// Extract parameters as JSON value + pub(crate) fn params_as_json(&self) -> serde_json::Value { + match self { + SparkError::CastInvalidValue { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::NumericValueOutOfRange { + value, + precision, + scale, + } => { + serde_json::json!({ + "value": value, + "precision": precision, + "scale": scale, + }) + } + SparkError::NumericOutOfRange { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::CastOverFlow { + value, + from_type, + to_type, + } => { + serde_json::json!({ + "value": value, + "fromType": from_type, + "toType": to_type, + }) + } + SparkError::ArithmeticOverflow { from_type } => { + serde_json::json!({ + "fromType": from_type, + }) + } + SparkError::BinaryArithmeticOverflow { + value1, + symbol, + value2, + function_name, + } => { + serde_json::json!({ + "value1": value1, + "symbol": symbol, + "value2": value2, + "functionName": function_name, + }) + } + SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { + serde_json::json!({ + "functionName": function_name, + }) + } + SparkError::InvalidArrayIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidElementAtIndex { + index_value, + array_size, + } => { + serde_json::json!({ + "indexValue": index_value, + "arraySize": array_size, + }) + } + SparkError::InvalidBitmapPosition { + bit_position, + bitmap_num_bytes, + bitmap_num_bits, + } => { + serde_json::json!({ + "bitPosition": bit_position, + "bitmapNumBytes": bitmap_num_bytes, + "bitmapNumBits": bitmap_num_bits, + }) + } + SparkError::DuplicatedMapKey { key } => { + 
serde_json::json!({ + "key": key, + }) + } + SparkError::ExceedMapSizeLimit { size, max_size } => { + serde_json::json!({ + "size": size, + "maxSize": max_size, + }) + } + SparkError::CollectionSizeLimitExceeded { + num_elements, + max_elements, + } => { + serde_json::json!({ + "numElements": num_elements, + "maxElements": max_elements, + }) + } + SparkError::NotNullAssertViolation { field_name } => { + serde_json::json!({ + "fieldName": field_name, + }) + } + SparkError::ValueIsNull { + field_name, + row_index, + } => { + serde_json::json!({ + "fieldName": field_name, + "rowIndex": row_index, + }) + } + SparkError::CannotParseTimestamp { + message, + suggested_func, + } => { + serde_json::json!({ + "message": message, + "suggestedFunc": suggested_func, + }) + } + SparkError::InvalidFractionOfSecond { value } => { + serde_json::json!({ + "value": value, + }) + } + SparkError::InvalidUtf8String { hex_string } => { + serde_json::json!({ + "hexString": hex_string, + }) + } + SparkError::UnexpectedPositiveValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::UnexpectedNegativeValue { + parameter_name, + actual_value, + } => { + serde_json::json!({ + "parameterName": parameter_name, + "actualValue": actual_value, + }) + } + SparkError::InvalidRegexGroupIndex { + function_name, + group_count, + group_index, + } => { + serde_json::json!({ + "functionName": function_name, + "groupCount": group_count, + "groupIndex": group_index, + }) + } + SparkError::DatatypeCannotOrder { data_type } => { + serde_json::json!({ + "dataType": data_type, + }) + } + SparkError::Arrow(e) => { + serde_json::json!({ + "message": e.to_string(), + }) + } + SparkError::Internal(msg) => { + serde_json::json!({ + "message": msg, + }) + } + // Simple errors with no parameters + _ => serde_json::json!({}), + } + } + + /// Returns the appropriate Spark exception class for this error + pub fn 
exception_class(&self) -> &'static str { + match self { + // ArithmeticException + SparkError::DivideByZero + | SparkError::RemainderByZero + | SparkError::IntervalDividedByZero + | SparkError::NumericValueOutOfRange { .. } + | SparkError::NumericOutOfRange { .. } // Comet-specific extension + | SparkError::ArithmeticOverflow { .. } + | SparkError::IntegralDivideOverflow + | SparkError::DecimalSumOverflow + | SparkError::BinaryArithmeticOverflow { .. } + | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } + | SparkError::IntervalArithmeticOverflowWithoutSuggestion + | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", + + // CastOverflow gets special handling with CastOverflowException + SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", + + // NumberFormatException (for cast invalid input errors) + SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", + + // ArrayIndexOutOfBoundsException + SparkError::InvalidArrayIndex { .. } + | SparkError::InvalidElementAtIndex { .. } + | SparkError::InvalidBitmapPosition { .. } + | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", + + // RuntimeException + SparkError::CannotParseDecimal + | SparkError::DuplicatedMapKey { .. } + | SparkError::NullMapKey + | SparkError::MapKeyValueDiffSizes + | SparkError::ExceedMapSizeLimit { .. } + | SparkError::CollectionSizeLimitExceeded { .. } + | SparkError::NotNullAssertViolation { .. } + | SparkError::ValueIsNull { .. } // Comet-specific extension + | SparkError::UnexpectedPositiveValue { .. } + | SparkError::UnexpectedNegativeValue { .. } + | SparkError::InvalidRegexGroupIndex { .. } + | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", + + // DateTimeException + SparkError::CannotParseTimestamp { .. } + | SparkError::InvalidFractionOfSecond { .. 
} => "org/apache/spark/SparkDateTimeException", + + // IllegalArgumentException + SparkError::DatatypeCannotOrder { .. } + | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", + + // Generic errors + SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", + } + } + + /// Returns the Spark error class code for this error + pub(crate) fn error_class(&self) -> Option<&'static str> { + match self { + // Cast errors + SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), + SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), + SparkError::NumericValueOutOfRange { .. } => { + Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") + } + SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), + SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), + + // Arithmetic errors + SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), + SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), + SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), + SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), + SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), + SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), + SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::IntervalArithmeticOverflowWithoutSuggestion => { + Some("INTERVAL_ARITHMETIC_OVERFLOW") + } + SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), + + // Array index errors + SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), + SparkError::InvalidElementAtIndex { .. } => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), + SparkError::InvalidBitmapPosition { .. 
} => Some("INVALID_BITMAP_POSITION"), + SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), + + // Map/Collection errors + SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), + SparkError::NullMapKey => Some("NULL_MAP_KEY"), + SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), + SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), + SparkError::CollectionSizeLimitExceeded { .. } => { + Some("COLLECTION_SIZE_LIMIT_EXCEEDED") + } + + // Null validation errors + SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), + SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), + + // DateTime errors + SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), + SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), + + // String/UTF8 errors + SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), + + // Function parameter errors + SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), + SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), + + // Regex errors + SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), + + // Unsupported operation errors + SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), + + // Subquery errors + SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), + + // Generic errors (no error class) + SparkError::Arrow(_) | SparkError::Internal(_) => None, + } + } +} + +pub type SparkResult<T> = Result<T, SparkError>; + +/// Convert decimal overflow to SparkError::NumericValueOutOfRange.
+pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { + SparkError::NumericValueOutOfRange { + value: value.to_string(), + precision, + scale, + } +} + +/// Wrapper that adds QueryContext to SparkError +/// +/// This allows attaching SQL context information (query text, line/position, object name) to errors +#[derive(Debug, Clone)] +pub struct SparkErrorWithContext { + /// The underlying SparkError + pub error: SparkError, + /// Optional QueryContext for SQL location information + pub context: Option<Arc<QueryContext>>, +} + +impl SparkErrorWithContext { + /// Create a SparkErrorWithContext without context + pub fn new(error: SparkError) -> Self { + Self { + error, + context: None, + } + } + + /// Create a SparkErrorWithContext with QueryContext + pub fn with_context(error: SparkError, context: Arc<QueryContext>) -> Self { + Self { + error, + context: Some(context), + } + } + + /// Serialize to JSON including optional context field + pub fn to_json(&self) -> String { + let mut json_obj = serde_json::json!({ + "errorType": self.error.error_type_name(), + "errorClass": self.error.error_class().unwrap_or(""), + "params": self.error.params_as_json(), + }); + + if let Some(ctx) = &self.context { + // Serialize context fields + json_obj["context"] = serde_json::json!({ + "sqlText": ctx.sql_text.as_str(), + "startIndex": ctx.start_index, + "stopIndex": ctx.stop_index, + "objectType": ctx.object_type, + "objectName": ctx.object_name, + "line": ctx.line, + "startPosition": ctx.start_position, + }); + + // Add formatted summary + json_obj["summary"] = serde_json::json!(ctx.format_summary()); + } + + serde_json::to_string(&json_obj).unwrap_or_else(|e| { + format!( + "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", + e + ) + }) + } +} + +impl std::fmt::Display for SparkErrorWithContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.error)?; + if let Some(ctx) = &self.context { + write!(f, "\n{}",
ctx.format_summary())?; + } + Ok(()) + } +} + +impl std::error::Error for SparkErrorWithContext {} + +impl From<SparkError> for SparkErrorWithContext { + fn from(error: SparkError) -> Self { + SparkErrorWithContext::new(error) + } +} + +impl From<SparkErrorWithContext> for DataFusionError { + fn from(value: SparkErrorWithContext) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +impl From<ArrowError> for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(Arc::new(value)) + } +} + +impl From<SparkError> for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_divide_by_zero_json() { + let error = SparkError::DivideByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"DivideByZero\"")); + assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DivideByZero"); + assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); + } + + #[test] + fn test_remainder_by_zero_json() { + let error = SparkError::RemainderByZero; + let json = error.to_json(); + + assert!(json.contains("\"errorType\":\"RemainderByZero\"")); + assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); + } + + #[test] + fn test_binary_overflow_json() { + let error = SparkError::BinaryArithmeticOverflow { + value1: "32767".to_string(), + symbol: "+".to_string(), + value2: "1".to_string(), + function_name: "try_add".to_string(), + }; + let json = error.to_json(); + + // Verify it's valid JSON + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); + assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); + assert_eq!(parsed["params"]["value1"], "32767"); + assert_eq!(parsed["params"]["symbol"], "+"); + assert_eq!(parsed["params"]["value2"], "1"); +
assert_eq!(parsed["params"]["functionName"], "try_add"); + } + + #[test] + fn test_invalid_array_index_json() { + let error = SparkError::InvalidArrayIndex { + index_value: 10, + array_size: 3, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "InvalidArrayIndex"); + assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); + assert_eq!(parsed["params"]["indexValue"], 10); + assert_eq!(parsed["params"]["arraySize"], 3); + } + + #[test] + fn test_numeric_value_out_of_range_json() { + let error = SparkError::NumericValueOutOfRange { + value: "999.99".to_string(), + precision: 5, + scale: 2, + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); + assert_eq!( + parsed["errorClass"], + "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" + ); + assert_eq!(parsed["params"]["value"], "999.99"); + assert_eq!(parsed["params"]["precision"], 5); + assert_eq!(parsed["params"]["scale"], 2); + } + + #[test] + fn test_cast_invalid_value_json() { + let error = SparkError::CastInvalidValue { + value: "abc".to_string(), + from_type: "STRING".to_string(), + to_type: "INT".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "CastInvalidValue"); + assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); + assert_eq!(parsed["params"]["value"], "abc"); + assert_eq!(parsed["params"]["fromType"], "STRING"); + assert_eq!(parsed["params"]["toType"], "INT"); + } + + #[test] + fn test_duplicated_map_key_json() { + let error = SparkError::DuplicatedMapKey { + key: "duplicate_key".to_string(), + }; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "DuplicatedMapKey"); + assert_eq!(parsed["errorClass"], 
"DUPLICATED_MAP_KEY"); + assert_eq!(parsed["params"]["key"], "duplicate_key"); + } + + #[test] + fn test_null_map_key_json() { + let error = SparkError::NullMapKey; + let json = error.to_json(); + + let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed["errorType"], "NullMapKey"); + assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); + // Params should be an empty object + assert_eq!(parsed["params"], serde_json::json!({})); + } + + #[test] + fn test_error_class_mapping() { + // Test that error_class() returns the correct error class + assert_eq!( + SparkError::DivideByZero.error_class(), + Some("DIVIDE_BY_ZERO") + ); + assert_eq!( + SparkError::RemainderByZero.error_class(), + Some("REMAINDER_BY_ZERO") + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .error_class(), + Some("INVALID_ARRAY_INDEX") + ); + assert_eq!(SparkError::NullMapKey.error_class(), Some("NULL_MAP_KEY")); + } + + #[test] + fn test_exception_class_mapping() { + // Test that exception_class() returns the correct Java exception class + assert_eq!( + SparkError::DivideByZero.exception_class(), + "org/apache/spark/SparkArithmeticException" + ); + assert_eq!( + SparkError::InvalidArrayIndex { + index_value: 0, + array_size: 0 + } + .exception_class(), + "org/apache/spark/SparkArrayIndexOutOfBoundsException" + ); + assert_eq!( + SparkError::NullMapKey.exception_class(), + "org/apache/spark/SparkRuntimeException" + ); + } +} diff --git a/native/spark-errors/src/lib.rs b/native/spark-errors/src/lib.rs new file mode 100644 index 0000000000..9319d7347f --- /dev/null +++ b/native/spark-errors/src/lib.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod error; +mod query_context; + +pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; +pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; diff --git a/native/spark-errors/src/query_context.rs b/native/spark-errors/src/query_context.rs new file mode 100644 index 0000000000..10ecef6550 --- /dev/null +++ b/native/spark-errors/src/query_context.rs @@ -0,0 +1,403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Query execution context for error reporting +//! +//! This module provides QueryContext which mirrors Spark's SQLQueryContext +//! 
for providing SQL text, line/position information, and error location +//! pointers in exception messages. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +/// Based on Spark's SQLQueryContext for error reporting. +/// +/// Contains information about where an error occurred in a SQL query, +/// including the full SQL text, line/column positions, and object context. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct QueryContext { + /// Full SQL query text + #[serde(rename = "sqlText")] + pub sql_text: Arc<String>, + + /// Start offset in SQL text (0-based, character index) + #[serde(rename = "startIndex")] + pub start_index: i32, + + /// Stop offset in SQL text (0-based, character index, inclusive) + #[serde(rename = "stopIndex")] + pub stop_index: i32, + + /// Object type (e.g., "VIEW", "Project", "Filter") + #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] + pub object_type: Option<String>, + + /// Object name (e.g., view name, column name) + #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] + pub object_name: Option<String>, + + /// Line number in SQL query (1-based) + pub line: i32, + + /// Column position within the line (0-based) + #[serde(rename = "startPosition")] + pub start_position: i32, +} + +impl QueryContext { + #[allow(clippy::too_many_arguments)] + pub fn new( + sql_text: String, + start_index: i32, + stop_index: i32, + object_type: Option<String>, + object_name: Option<String>, + line: i32, + start_position: i32, + ) -> Self { + Self { + sql_text: Arc::new(sql_text), + start_index, + stop_index, + object_type, + object_name, + line, + start_position, + } + } + + /// Convert a character index to a byte offset in the SQL text. + /// Returns None if the character index is out of range.
+ fn char_index_to_byte_offset(&self, char_index: usize) -> Option<usize> { + self.sql_text + .char_indices() + .nth(char_index) + .map(|(byte_offset, _)| byte_offset) + } + + /// Generate a summary string showing SQL fragment with error location. + /// (From SQLQueryContext.summary) + /// + /// Format example: + /// ```text + /// == SQL of VIEW v1 (line 1, position 8) == + /// SELECT a/b FROM t + /// ^^^ + /// ``` + pub fn format_summary(&self) -> String { + let start_char = self.start_index.max(0) as usize; + // stop_index is inclusive; fragment covers [start, stop] + let stop_char = (self.stop_index + 1).max(0) as usize; + + let fragment = match ( + self.char_index_to_byte_offset(start_char), + // stop_char may equal sql_text.chars().count() (one past the end) + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], + _ => "", + }; + + // Build the header line + let mut summary = String::from("== SQL"); + + if let Some(obj_type) = &self.object_type { + if !obj_type.is_empty() { + summary.push_str(" of "); + summary.push_str(obj_type); + + if let Some(obj_name) = &self.object_name { + if !obj_name.is_empty() { + summary.push(' '); + summary.push_str(obj_name); + } + } + } + } + + summary.push_str(&format!( + " (line {}, position {}) ==\n", + self.line, + self.start_position + 1 // Convert 0-based to 1-based for display + )); + + // Add the SQL text with fragment highlighted + summary.push_str(&self.sql_text); + summary.push('\n'); + + // Add caret pointer + let caret_position = self.start_position.max(0) as usize; + summary.push_str(&" ".repeat(caret_position)); + // fragment.chars().count() gives the correct display width for non-ASCII + summary.push_str(&"^".repeat(fragment.chars().count().max(1))); + + summary + } + + /// Returns the SQL fragment that caused the error.
+ #[cfg(test)] + fn fragment(&self) -> String { + let start_char = self.start_index.max(0) as usize; + let stop_char = (self.stop_index + 1).max(0) as usize; + + match ( + self.char_index_to_byte_offset(start_char), + self.char_index_to_byte_offset(stop_char).or_else(|| { + if stop_char == self.sql_text.chars().count() { + Some(self.sql_text.len()) + } else { + None + } + }), + ) { + (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), + _ => String::new(), + } + } +} + +use std::collections::HashMap; +use std::sync::RwLock; + +/// Map that stores QueryContext information for expressions during execution. +/// +/// This map is populated during plan deserialization and accessed +/// during error creation to attach SQL context to exceptions. +#[derive(Debug)] +pub struct QueryContextMap { + /// Map from expression ID to QueryContext + contexts: RwLock<HashMap<u64, Arc<QueryContext>>>, +} + +impl QueryContextMap { + pub fn new() -> Self { + Self { + contexts: RwLock::new(HashMap::new()), + } + } + + /// Register a QueryContext for an expression ID. + /// + /// If the expression ID already exists, it will be replaced. + /// + /// # Arguments + /// * `expr_id` - Unique expression identifier from protobuf + /// * `context` - QueryContext containing SQL text and position info + pub fn register(&self, expr_id: u64, context: QueryContext) { + let mut contexts = self.contexts.write().unwrap(); + contexts.insert(expr_id, Arc::new(context)); + } + + /// Get the QueryContext for an expression ID. + /// + /// Returns None if no context is registered for this expression. + /// + /// # Arguments + /// * `expr_id` - Expression identifier to look up + pub fn get(&self, expr_id: u64) -> Option<Arc<QueryContext>> { + let contexts = self.contexts.read().unwrap(); + contexts.get(&expr_id).cloned() + } + + /// Clear all registered contexts. + /// + /// This is typically called after plan execution completes to free memory.
+ pub fn clear(&self) { + let mut contexts = self.contexts.write().unwrap(); + contexts.clear(); + } + + /// Return the number of registered contexts (for debugging/testing) + pub fn len(&self) -> usize { + let contexts = self.contexts.read().unwrap(); + contexts.len() + } + + /// Check if the map is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for QueryContextMap { + fn default() -> Self { + Self::new() + } +} + +/// Create a new session-scoped QueryContextMap. +/// +/// This should be called once per SessionContext during plan creation +/// and passed to expressions that need query context for error reporting. +pub fn create_query_context_map() -> Arc<QueryContextMap> { + Arc::new(QueryContextMap::new()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query_context_creation() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); + assert_eq!(ctx.start_index, 7); + assert_eq!(ctx.stop_index, 9); + assert_eq!(ctx.object_type, Some("Divide".to_string())); + assert_eq!(ctx.object_name, Some("a/b".to_string())); + assert_eq!(ctx.line, 1); + assert_eq!(ctx.start_position, 7); + } + + #[test] + fn test_query_context_serialization() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("Divide".to_string()), + Some("a/b".to_string()), + 1, + 7, + ); + + let json = serde_json::to_string(&ctx).unwrap(); + let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); + + assert_eq!(ctx, deserialized); + } + + #[test] + fn test_format_summary() { + let ctx = QueryContext::new( + "SELECT a/b FROM t".to_string(), + 7, + 9, + Some("VIEW".to_string()), + Some("v1".to_string()), + 1, + 7, + ); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM
t")); + assert!(summary.contains("^^^")); // Three carets for "a/b" + } + + #[test] + fn test_format_summary_without_object() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let summary = ctx.format_summary(); + + assert!(summary.contains("== SQL (line 1, position 8) ==")); + assert!(summary.contains("SELECT a/b FROM t")); + } + + #[test] + fn test_fragment() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + assert_eq!(ctx.fragment(), "a/b"); + } + + #[test] + fn test_arc_string_sharing() { + let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let ctx2 = ctx1.clone(); + + // Arc should share the same allocation + assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); + } + + #[test] + fn test_json_with_optional_fields() { + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + let json = serde_json::to_string(&ctx).unwrap(); + + // Should not serialize objectType and objectName when None + assert!(!json.contains("objectType")); + assert!(!json.contains("objectName")); + } + + #[test] + fn test_map_register_and_get() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx.clone()); + + let retrieved = map.get(1).unwrap(); + assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); + assert_eq!(retrieved.start_index, 7); + } + + #[test] + fn test_map_get_nonexistent() { + let map = QueryContextMap::new(); + assert!(map.get(999).is_none()); + } + + #[test] + fn test_map_clear() { + let map = QueryContextMap::new(); + + let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); + + map.register(1, ctx); + assert_eq!(map.len(), 1); + + map.clear(); + assert_eq!(map.len(), 0); + assert!(map.is_empty()); + } + + // Verify that fragment() and format_summary() correctly handle SQL text that + // 
contains multi-byte characters + + #[test] + fn test_fragment_non_ascii_accented() { + // "é" is a 2-byte UTF-8 sequence (U+00E9). + // SQL: "SELECT café FROM t" + // 0123456789... + // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. + // start_index=7, stop_index=10 should yield "café" + let sql = "SELECT café FROM t".to_string(); + let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); + assert_eq!(ctx.fragment(), "café"); + } +} diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 9f08e480f2..990553a60f 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -36,6 +36,7 @@ regex = { workspace = true } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" thiserror = { workspace = true } +datafusion-comet-spark-errors = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index bd7efe1638..9d0e765b92 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -15,845 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::error::ArrowError; -use datafusion::common::DataFusionError; -use std::sync::Arc; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum SparkError { - // This list was generated from the Spark code. Many of the exceptions are not yet used by Comet - #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - because it is malformed. Correct the value as per the syntax, or change its target type. \ - Use `try_cast` to tolerate malformed input and return NULL instead. 
If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastInvalidValue { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] - NumericValueOutOfRange { - value: String, - precision: u8, - scale: i8, - }, - - #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")] - NumericOutOfRange { value: String }, - - #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ - due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ - set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - CastOverFlow { - value: String, - from_type: String, - to_type: String, - }, - - #[error("[CANNOT_PARSE_DECIMAL] Cannot parse decimal.")] - CannotParseDecimal, - - #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - ArithmeticOverflow { from_type: String }, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in integral divide. Use `try_divide` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntegralDivideOverflow, - - #[error("[ARITHMETIC_OVERFLOW] Overflow in sum of decimals. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DecimalSumOverflow, - - #[error("[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - DivideByZero, - - #[error("[REMAINDER_BY_ZERO] Division by zero. 
Use `try_remainder` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - RemainderByZero, - - #[error("[INTERVAL_DIVIDED_BY_ZERO] Divide by zero in interval arithmetic.")] - IntervalDividedByZero, - - #[error("[BINARY_ARITHMETIC_OVERFLOW] {value1} {symbol} {value2} caused overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - BinaryArithmeticOverflow { - value1: String, - symbol: String, - value2: String, - function_name: String, - }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITH_SUGGESTION] Interval arithmetic overflow. Use `{function_name}` to tolerate overflow and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithSuggestion { function_name: String }, - - #[error("[INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION] Interval arithmetic overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - IntervalArithmeticOverflowWithoutSuggestion, - - #[error("[DATETIME_OVERFLOW] Datetime arithmetic overflow.")] - DatetimeOverflow, - - #[error("[INVALID_ARRAY_INDEX] The index {index_value} is out of bounds. The array has {array_size} elements. Use the SQL function get() to tolerate accessing element at invalid index and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidArrayIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_ARRAY_INDEX_IN_ELEMENT_AT] The index {index_value} is out of bounds. The array has {array_size} elements. Use try_element_at to tolerate accessing element at invalid index and return NULL instead. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidElementAtIndex { index_value: i32, array_size: i32 }, - - #[error("[INVALID_BITMAP_POSITION] The bit position {bit_position} is out of bounds. The bitmap has {bitmap_num_bytes} bytes ({bitmap_num_bits} bits).")] - InvalidBitmapPosition { - bit_position: i64, - bitmap_num_bytes: i64, - bitmap_num_bits: i64, - }, - - #[error("[INVALID_INDEX_OF_ZERO] The index 0 is invalid. An index shall be either < 0 or > 0 (the first element has index 1).")] - InvalidIndexOfZero, - - #[error("[DUPLICATED_MAP_KEY] Cannot create map with duplicate keys: {key}.")] - DuplicatedMapKey { key: String }, - - #[error("[NULL_MAP_KEY] Cannot use null as map key.")] - NullMapKey, - - #[error("[MAP_KEY_VALUE_DIFF_SIZES] The key array and value array of a map must have the same length.")] - MapKeyValueDiffSizes, - - #[error("[EXCEED_LIMIT_LENGTH] Cannot create a map with {size} elements which exceeds the limit {max_size}.")] - ExceedMapSizeLimit { size: i32, max_size: i32 }, - - #[error("[COLLECTION_SIZE_LIMIT_EXCEEDED] Cannot create array with {num_elements} elements which exceeds the limit {max_elements}.")] - CollectionSizeLimitExceeded { - num_elements: i64, - max_elements: i64, - }, - - #[error("[NOT_NULL_ASSERT_VIOLATION] The field `{field_name}` cannot be null.")] - NotNullAssertViolation { field_name: String }, - - #[error("[VALUE_IS_NULL] The value of field `{field_name}` at row {row_index} is null.")] - ValueIsNull { field_name: String, row_index: i32 }, - - #[error("[CANNOT_PARSE_TIMESTAMP] Cannot parse timestamp: {message}. Try using `{suggested_func}` instead.")] - CannotParseTimestamp { - message: String, - suggested_func: String, - }, - - #[error("[INVALID_FRACTION_OF_SECOND] The fraction of second {value} is invalid. Valid values are in the range [0, 60]. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] - InvalidFractionOfSecond { value: f64 }, - - #[error("[INVALID_UTF8_STRING] Invalid UTF-8 string: {hex_string}.")] - InvalidUtf8String { hex_string: String }, - - #[error("[UNEXPECTED_POSITIVE_VALUE] The {parameter_name} parameter must be less than or equal to 0. The actual value is {actual_value}.")] - UnexpectedPositiveValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[UNEXPECTED_NEGATIVE_VALUE] The {parameter_name} parameter must be greater than or equal to 0. The actual value is {actual_value}.")] - UnexpectedNegativeValue { - parameter_name: String, - actual_value: i32, - }, - - #[error("[INVALID_PARAMETER_VALUE] Invalid regex group index {group_index} in function `{function_name}`. Group count is {group_count}.")] - InvalidRegexGroupIndex { - function_name: String, - group_count: i32, - group_index: i32, - }, - - #[error("[DATATYPE_CANNOT_ORDER] Cannot order by type: {data_type}.")] - DatatypeCannotOrder { data_type: String }, - - #[error("[SCALAR_SUBQUERY_TOO_MANY_ROWS] Scalar subquery returned more than one row.")] - ScalarSubqueryTooManyRows, - - #[error("ArrowError: {0}.")] - Arrow(Arc), - - #[error("InternalError: {0}.")] - Internal(String), -} - -impl SparkError { - /// Serialize this error to JSON format for JNI transfer - pub fn to_json(&self) -> String { - let error_class = self.error_class().unwrap_or(""); - - // Create a JSON structure with errorType, errorClass, and params - match serde_json::to_string(&serde_json::json!({ - "errorType": self.error_type_name(), - "errorClass": error_class, - "params": self.params_as_json(), - })) { - Ok(json) => json, - Err(e) => { - // Fallback if serialization fails - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - } - } - } - - /// Get the error type name for JSON serialization - fn error_type_name(&self) -> &'static str { - match self { - SparkError::CastInvalidValue { 
.. } => "CastInvalidValue", - SparkError::NumericValueOutOfRange { .. } => "NumericValueOutOfRange", - SparkError::NumericOutOfRange { .. } => "NumericOutOfRange", - SparkError::CastOverFlow { .. } => "CastOverFlow", - SparkError::CannotParseDecimal => "CannotParseDecimal", - SparkError::ArithmeticOverflow { .. } => "ArithmeticOverflow", - SparkError::IntegralDivideOverflow => "IntegralDivideOverflow", - SparkError::DecimalSumOverflow => "DecimalSumOverflow", - SparkError::DivideByZero => "DivideByZero", - SparkError::RemainderByZero => "RemainderByZero", - SparkError::IntervalDividedByZero => "IntervalDividedByZero", - SparkError::BinaryArithmeticOverflow { .. } => "BinaryArithmeticOverflow", - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - "IntervalArithmeticOverflowWithSuggestion" - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - "IntervalArithmeticOverflowWithoutSuggestion" - } - SparkError::DatetimeOverflow => "DatetimeOverflow", - SparkError::InvalidArrayIndex { .. } => "InvalidArrayIndex", - SparkError::InvalidElementAtIndex { .. } => "InvalidElementAtIndex", - SparkError::InvalidBitmapPosition { .. } => "InvalidBitmapPosition", - SparkError::InvalidIndexOfZero => "InvalidIndexOfZero", - SparkError::DuplicatedMapKey { .. } => "DuplicatedMapKey", - SparkError::NullMapKey => "NullMapKey", - SparkError::MapKeyValueDiffSizes => "MapKeyValueDiffSizes", - SparkError::ExceedMapSizeLimit { .. } => "ExceedMapSizeLimit", - SparkError::CollectionSizeLimitExceeded { .. } => "CollectionSizeLimitExceeded", - SparkError::NotNullAssertViolation { .. } => "NotNullAssertViolation", - SparkError::ValueIsNull { .. } => "ValueIsNull", - SparkError::CannotParseTimestamp { .. } => "CannotParseTimestamp", - SparkError::InvalidFractionOfSecond { .. } => "InvalidFractionOfSecond", - SparkError::InvalidUtf8String { .. } => "InvalidUtf8String", - SparkError::UnexpectedPositiveValue { .. 
} => "UnexpectedPositiveValue", - SparkError::UnexpectedNegativeValue { .. } => "UnexpectedNegativeValue", - SparkError::InvalidRegexGroupIndex { .. } => "InvalidRegexGroupIndex", - SparkError::DatatypeCannotOrder { .. } => "DatatypeCannotOrder", - SparkError::ScalarSubqueryTooManyRows => "ScalarSubqueryTooManyRows", - SparkError::Arrow(_) => "Arrow", - SparkError::Internal(_) => "Internal", - } - } - - /// Extract parameters as JSON value - fn params_as_json(&self) -> serde_json::Value { - match self { - SparkError::CastInvalidValue { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::NumericValueOutOfRange { - value, - precision, - scale, - } => { - serde_json::json!({ - "value": value, - "precision": precision, - "scale": scale, - }) - } - SparkError::NumericOutOfRange { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::CastOverFlow { - value, - from_type, - to_type, - } => { - serde_json::json!({ - "value": value, - "fromType": from_type, - "toType": to_type, - }) - } - SparkError::ArithmeticOverflow { from_type } => { - serde_json::json!({ - "fromType": from_type, - }) - } - SparkError::BinaryArithmeticOverflow { - value1, - symbol, - value2, - function_name, - } => { - serde_json::json!({ - "value1": value1, - "symbol": symbol, - "value2": value2, - "functionName": function_name, - }) - } - SparkError::IntervalArithmeticOverflowWithSuggestion { function_name } => { - serde_json::json!({ - "functionName": function_name, - }) - } - SparkError::InvalidArrayIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidElementAtIndex { - index_value, - array_size, - } => { - serde_json::json!({ - "indexValue": index_value, - "arraySize": array_size, - }) - } - SparkError::InvalidBitmapPosition { - bit_position, - bitmap_num_bytes, - bitmap_num_bits, 
- } => { - serde_json::json!({ - "bitPosition": bit_position, - "bitmapNumBytes": bitmap_num_bytes, - "bitmapNumBits": bitmap_num_bits, - }) - } - SparkError::DuplicatedMapKey { key } => { - serde_json::json!({ - "key": key, - }) - } - SparkError::ExceedMapSizeLimit { size, max_size } => { - serde_json::json!({ - "size": size, - "maxSize": max_size, - }) - } - SparkError::CollectionSizeLimitExceeded { - num_elements, - max_elements, - } => { - serde_json::json!({ - "numElements": num_elements, - "maxElements": max_elements, - }) - } - SparkError::NotNullAssertViolation { field_name } => { - serde_json::json!({ - "fieldName": field_name, - }) - } - SparkError::ValueIsNull { - field_name, - row_index, - } => { - serde_json::json!({ - "fieldName": field_name, - "rowIndex": row_index, - }) - } - SparkError::CannotParseTimestamp { - message, - suggested_func, - } => { - serde_json::json!({ - "message": message, - "suggestedFunc": suggested_func, - }) - } - SparkError::InvalidFractionOfSecond { value } => { - serde_json::json!({ - "value": value, - }) - } - SparkError::InvalidUtf8String { hex_string } => { - serde_json::json!({ - "hexString": hex_string, - }) - } - SparkError::UnexpectedPositiveValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::UnexpectedNegativeValue { - parameter_name, - actual_value, - } => { - serde_json::json!({ - "parameterName": parameter_name, - "actualValue": actual_value, - }) - } - SparkError::InvalidRegexGroupIndex { - function_name, - group_count, - group_index, - } => { - serde_json::json!({ - "functionName": function_name, - "groupCount": group_count, - "groupIndex": group_index, - }) - } - SparkError::DatatypeCannotOrder { data_type } => { - serde_json::json!({ - "dataType": data_type, - }) - } - SparkError::Arrow(e) => { - serde_json::json!({ - "message": e.to_string(), - }) - } - SparkError::Internal(msg) => { - 
serde_json::json!({ - "message": msg, - }) - } - // Simple errors with no parameters - _ => serde_json::json!({}), - } - } - - /// Returns the appropriate Spark exception class for this error - pub fn exception_class(&self) -> &'static str { - match self { - // ArithmeticException - SparkError::DivideByZero - | SparkError::RemainderByZero - | SparkError::IntervalDividedByZero - | SparkError::NumericValueOutOfRange { .. } - | SparkError::NumericOutOfRange { .. } // Comet-specific extension - | SparkError::ArithmeticOverflow { .. } - | SparkError::IntegralDivideOverflow - | SparkError::DecimalSumOverflow - | SparkError::BinaryArithmeticOverflow { .. } - | SparkError::IntervalArithmeticOverflowWithSuggestion { .. } - | SparkError::IntervalArithmeticOverflowWithoutSuggestion - | SparkError::DatetimeOverflow => "org/apache/spark/SparkArithmeticException", - - // CastOverflow gets special handling with CastOverflowException - SparkError::CastOverFlow { .. } => "org/apache/spark/sql/comet/CastOverflowException", - - // NumberFormatException (for cast invalid input errors) - SparkError::CastInvalidValue { .. } => "org/apache/spark/SparkNumberFormatException", - - // ArrayIndexOutOfBoundsException - SparkError::InvalidArrayIndex { .. } - | SparkError::InvalidElementAtIndex { .. } - | SparkError::InvalidBitmapPosition { .. } - | SparkError::InvalidIndexOfZero => "org/apache/spark/SparkArrayIndexOutOfBoundsException", - - // RuntimeException - SparkError::CannotParseDecimal - | SparkError::DuplicatedMapKey { .. } - | SparkError::NullMapKey - | SparkError::MapKeyValueDiffSizes - | SparkError::ExceedMapSizeLimit { .. } - | SparkError::CollectionSizeLimitExceeded { .. } - | SparkError::NotNullAssertViolation { .. } - | SparkError::ValueIsNull { .. } // Comet-specific extension - | SparkError::UnexpectedPositiveValue { .. } - | SparkError::UnexpectedNegativeValue { .. } - | SparkError::InvalidRegexGroupIndex { .. 
} - | SparkError::ScalarSubqueryTooManyRows => "org/apache/spark/SparkRuntimeException", - - // DateTimeException - SparkError::CannotParseTimestamp { .. } - | SparkError::InvalidFractionOfSecond { .. } => "org/apache/spark/SparkDateTimeException", - - // IllegalArgumentException - SparkError::DatatypeCannotOrder { .. } - | SparkError::InvalidUtf8String { .. } => "org/apache/spark/SparkIllegalArgumentException", - - // Generic errors - SparkError::Arrow(_) | SparkError::Internal(_) => "org/apache/spark/SparkException", - } - } - - /// Returns the Spark error class code for this error - fn error_class(&self) -> Option<&'static str> { - match self { - // Cast errors - SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"), - SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"), - SparkError::NumericValueOutOfRange { .. } => { - Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION") - } - SparkError::NumericOutOfRange { .. } => Some("NUMERIC_OUT_OF_SUPPORTED_RANGE"), - SparkError::CannotParseDecimal => Some("CANNOT_PARSE_DECIMAL"), - - // Arithmetic errors - SparkError::DivideByZero => Some("DIVIDE_BY_ZERO"), - SparkError::RemainderByZero => Some("REMAINDER_BY_ZERO"), - SparkError::IntervalDividedByZero => Some("INTERVAL_DIVIDED_BY_ZERO"), - SparkError::ArithmeticOverflow { .. } => Some("ARITHMETIC_OVERFLOW"), - SparkError::IntegralDivideOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::DecimalSumOverflow => Some("ARITHMETIC_OVERFLOW"), - SparkError::BinaryArithmeticOverflow { .. } => Some("BINARY_ARITHMETIC_OVERFLOW"), - SparkError::IntervalArithmeticOverflowWithSuggestion { .. } => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::IntervalArithmeticOverflowWithoutSuggestion => { - Some("INTERVAL_ARITHMETIC_OVERFLOW") - } - SparkError::DatetimeOverflow => Some("DATETIME_OVERFLOW"), - - // Array index errors - SparkError::InvalidArrayIndex { .. } => Some("INVALID_ARRAY_INDEX"), - SparkError::InvalidElementAtIndex { .. 
} => Some("INVALID_ARRAY_INDEX_IN_ELEMENT_AT"), - SparkError::InvalidBitmapPosition { .. } => Some("INVALID_BITMAP_POSITION"), - SparkError::InvalidIndexOfZero => Some("INVALID_INDEX_OF_ZERO"), - - // Map/Collection errors - SparkError::DuplicatedMapKey { .. } => Some("DUPLICATED_MAP_KEY"), - SparkError::NullMapKey => Some("NULL_MAP_KEY"), - SparkError::MapKeyValueDiffSizes => Some("MAP_KEY_VALUE_DIFF_SIZES"), - SparkError::ExceedMapSizeLimit { .. } => Some("EXCEED_LIMIT_LENGTH"), - SparkError::CollectionSizeLimitExceeded { .. } => { - Some("COLLECTION_SIZE_LIMIT_EXCEEDED") - } - - // Null validation errors - SparkError::NotNullAssertViolation { .. } => Some("NOT_NULL_ASSERT_VIOLATION"), - SparkError::ValueIsNull { .. } => Some("VALUE_IS_NULL"), - - // DateTime errors - SparkError::CannotParseTimestamp { .. } => Some("CANNOT_PARSE_TIMESTAMP"), - SparkError::InvalidFractionOfSecond { .. } => Some("INVALID_FRACTION_OF_SECOND"), - - // String/UTF8 errors - SparkError::InvalidUtf8String { .. } => Some("INVALID_UTF8_STRING"), - - // Function parameter errors - SparkError::UnexpectedPositiveValue { .. } => Some("UNEXPECTED_POSITIVE_VALUE"), - SparkError::UnexpectedNegativeValue { .. } => Some("UNEXPECTED_NEGATIVE_VALUE"), - - // Regex errors - SparkError::InvalidRegexGroupIndex { .. } => Some("INVALID_PARAMETER_VALUE"), - - // Unsupported operation errors - SparkError::DatatypeCannotOrder { .. } => Some("DATATYPE_CANNOT_ORDER"), - - // Subquery errors - SparkError::ScalarSubqueryTooManyRows => Some("SCALAR_SUBQUERY_TOO_MANY_ROWS"), - - // Generic errors (no error class) - SparkError::Arrow(_) | SparkError::Internal(_) => None, - } - } -} - -pub type SparkResult = Result; - -/// Convert decimal overflow to SparkError::NumericValueOutOfRange. 
-pub fn decimal_overflow_error(value: i128, precision: u8, scale: i8) -> SparkError { - SparkError::NumericValueOutOfRange { - value: value.to_string(), - precision, - scale, - } -} - -/// Wrapper that adds QueryContext to SparkError -/// -/// This allows attaching SQL context information (query text, line/position, object name) to errors -#[derive(Debug, Clone)] -pub struct SparkErrorWithContext { - /// The underlying SparkError - pub error: SparkError, - /// Optional QueryContext for SQL location information - pub context: Option>, -} - -impl SparkErrorWithContext { - /// Create a SparkErrorWithContext without context - pub fn new(error: SparkError) -> Self { - Self { - error, - context: None, - } - } - - /// Create a SparkErrorWithContext with QueryContext - pub fn with_context(error: SparkError, context: Arc) -> Self { - Self { - error, - context: Some(context), - } - } - - /// Serialize to JSON including optional context field - /// - /// JSON structure: - /// ```json - /// { - /// "errorType": "DivideByZero", - /// "errorClass": "DIVIDE_BY_ZERO", - /// "params": {}, - /// "context": { - /// "sqlText": "SELECT a/b FROM t", - /// "startIndex": 7, - /// "stopIndex": 9, - /// "line": 1, - /// "startPosition": 7 - /// }, - /// "summary": "== SQL (line 1, position 8) ==\n..." 
- /// } - /// ``` - pub fn to_json(&self) -> String { - let mut json_obj = serde_json::json!({ - "errorType": self.error.error_type_name(), - "errorClass": self.error.error_class().unwrap_or(""), - "params": self.error.params_as_json(), - }); - - if let Some(ctx) = &self.context { - // Serialize context fields - json_obj["context"] = serde_json::json!({ - "sqlText": ctx.sql_text.as_str(), - "startIndex": ctx.start_index, - "stopIndex": ctx.stop_index, - "objectType": ctx.object_type, - "objectName": ctx.object_name, - "line": ctx.line, - "startPosition": ctx.start_position, - }); - - // Add formatted summary - json_obj["summary"] = serde_json::json!(ctx.format_summary()); - } - - serde_json::to_string(&json_obj).unwrap_or_else(|e| { - format!( - "{{\"errorType\":\"SerializationError\",\"message\":\"{}\"}}", - e - ) - }) - } -} - -impl std::fmt::Display for SparkErrorWithContext { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error)?; - if let Some(ctx) = &self.context { - write!(f, "\n{}", ctx.format_summary())?; - } - Ok(()) - } -} - -impl std::error::Error for SparkErrorWithContext {} - -impl From for SparkErrorWithContext { - fn from(error: SparkError) -> Self { - SparkErrorWithContext::new(error) - } -} - -impl From for DataFusionError { - fn from(value: SparkErrorWithContext) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -impl From for SparkError { - fn from(value: ArrowError) -> Self { - SparkError::Arrow(Arc::new(value)) - } -} - -impl From for DataFusionError { - fn from(value: SparkError) -> Self { - DataFusionError::External(Box::new(value)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_divide_by_zero_json() { - let error = SparkError::DivideByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"DivideByZero\"")); - assert!(json.contains("\"errorClass\":\"DIVIDE_BY_ZERO\"")); - - // Verify it's valid JSON - let parsed: serde_json::Value 
= serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DivideByZero"); - assert_eq!(parsed["errorClass"], "DIVIDE_BY_ZERO"); - } - - #[test] - fn test_remainder_by_zero_json() { - let error = SparkError::RemainderByZero; - let json = error.to_json(); - - assert!(json.contains("\"errorType\":\"RemainderByZero\"")); - assert!(json.contains("\"errorClass\":\"REMAINDER_BY_ZERO\"")); - } - - #[test] - fn test_binary_overflow_json() { - let error = SparkError::BinaryArithmeticOverflow { - value1: "32767".to_string(), - symbol: "+".to_string(), - value2: "1".to_string(), - function_name: "try_add".to_string(), - }; - let json = error.to_json(); - - // Verify it's valid JSON - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "BinaryArithmeticOverflow"); - assert_eq!(parsed["errorClass"], "BINARY_ARITHMETIC_OVERFLOW"); - assert_eq!(parsed["params"]["value1"], "32767"); - assert_eq!(parsed["params"]["symbol"], "+"); - assert_eq!(parsed["params"]["value2"], "1"); - assert_eq!(parsed["params"]["functionName"], "try_add"); - } - - #[test] - fn test_invalid_array_index_json() { - let error = SparkError::InvalidArrayIndex { - index_value: 10, - array_size: 3, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "InvalidArrayIndex"); - assert_eq!(parsed["errorClass"], "INVALID_ARRAY_INDEX"); - assert_eq!(parsed["params"]["indexValue"], 10); - assert_eq!(parsed["params"]["arraySize"], 3); - } - - #[test] - fn test_numeric_value_out_of_range_json() { - let error = SparkError::NumericValueOutOfRange { - value: "999.99".to_string(), - precision: 5, - scale: 2, - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NumericValueOutOfRange"); - assert_eq!( - parsed["errorClass"], - "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION" - ); - 
assert_eq!(parsed["params"]["value"], "999.99"); - assert_eq!(parsed["params"]["precision"], 5); - assert_eq!(parsed["params"]["scale"], 2); - } - - #[test] - fn test_cast_invalid_value_json() { - let error = SparkError::CastInvalidValue { - value: "abc".to_string(), - from_type: "STRING".to_string(), - to_type: "INT".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "CastInvalidValue"); - assert_eq!(parsed["errorClass"], "CAST_INVALID_INPUT"); - assert_eq!(parsed["params"]["value"], "abc"); - assert_eq!(parsed["params"]["fromType"], "STRING"); - assert_eq!(parsed["params"]["toType"], "INT"); - } - - #[test] - fn test_duplicated_map_key_json() { - let error = SparkError::DuplicatedMapKey { - key: "duplicate_key".to_string(), - }; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "DuplicatedMapKey"); - assert_eq!(parsed["errorClass"], "DUPLICATED_MAP_KEY"); - assert_eq!(parsed["params"]["key"], "duplicate_key"); - } - - #[test] - fn test_null_map_key_json() { - let error = SparkError::NullMapKey; - let json = error.to_json(); - - let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed["errorType"], "NullMapKey"); - assert_eq!(parsed["errorClass"], "NULL_MAP_KEY"); - // Params should be an empty object - assert_eq!(parsed["params"], serde_json::json!({})); - } - - #[test] - fn test_error_class_mapping() { - // Test that error_class() returns the correct error class - assert_eq!( - SparkError::DivideByZero.error_class(), - Some("DIVIDE_BY_ZERO") - ); - assert_eq!( - SparkError::RemainderByZero.error_class(), - Some("REMAINDER_BY_ZERO") - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .error_class(), - Some("INVALID_ARRAY_INDEX") - ); - assert_eq!(SparkError::NullMapKey.error_class(), 
Some("NULL_MAP_KEY")); - } - - #[test] - fn test_exception_class_mapping() { - // Test that exception_class() returns the correct Java exception class - assert_eq!( - SparkError::DivideByZero.exception_class(), - "org/apache/spark/SparkArithmeticException" - ); - assert_eq!( - SparkError::InvalidArrayIndex { - index_value: 0, - array_size: 0 - } - .exception_class(), - "org/apache/spark/SparkArrayIndexOutOfBoundsException" - ); - assert_eq!( - SparkError::NullMapKey.exception_class(), - "org/apache/spark/SparkRuntimeException" - ); - } -} +// Re-export all error types from the spark-errors crate +pub use datafusion_comet_spark_errors::{ + decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult, +}; diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs index 10ecef6550..431ba54867 100644 --- a/native/spark-expr/src/query_context.rs +++ b/native/spark-expr/src/query_context.rs @@ -15,389 +15,5 @@ // specific language governing permissions and limitations // under the License. -//! Query execution context for error reporting -//! -//! This module provides QueryContext which mirrors Spark's SQLQueryContext -//! for providing SQL text, line/position information, and error location -//! pointers in exception messages. - -use serde::{Deserialize, Serialize}; -use std::sync::Arc; - -/// Based on Spark's SQLQueryContext for error reporting. -/// -/// Contains information about where an error occurred in a SQL query, -/// including the full SQL text, line/column positions, and object context. 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct QueryContext { - /// Full SQL query text - #[serde(rename = "sqlText")] - pub sql_text: Arc, - - /// Start offset in SQL text (0-based, character index) - #[serde(rename = "startIndex")] - pub start_index: i32, - - /// Stop offset in SQL text (0-based, character index, inclusive) - #[serde(rename = "stopIndex")] - pub stop_index: i32, - - /// Object type (e.g., "VIEW", "Project", "Filter") - #[serde(rename = "objectType", skip_serializing_if = "Option::is_none")] - pub object_type: Option, - - /// Object name (e.g., view name, column name) - #[serde(rename = "objectName", skip_serializing_if = "Option::is_none")] - pub object_name: Option, - - /// Line number in SQL query (1-based) - pub line: i32, - - /// Column position within the line (0-based) - #[serde(rename = "startPosition")] - pub start_position: i32, -} - -impl QueryContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - sql_text: String, - start_index: i32, - stop_index: i32, - object_type: Option, - object_name: Option, - line: i32, - start_position: i32, - ) -> Self { - Self { - sql_text: Arc::new(sql_text), - start_index, - stop_index, - object_type, - object_name, - line, - start_position, - } - } - - /// Convert a character index to a byte offset in the SQL text. - /// Returns None if the character index is out of range. - fn char_index_to_byte_offset(&self, char_index: usize) -> Option { - self.sql_text - .char_indices() - .nth(char_index) - .map(|(byte_offset, _)| byte_offset) - } - - /// Generate a summary string showing SQL fragment with error location. 
- /// (From SQLQueryContext.summary) - /// - /// Format example: - /// ```text - /// == SQL of VIEW v1 (line 1, position 8) == - /// SELECT a/b FROM t - /// ^^^ - /// ``` - pub fn format_summary(&self) -> String { - let start_char = self.start_index.max(0) as usize; - // stop_index is inclusive; fragment covers [start, stop] - let stop_char = (self.stop_index + 1).max(0) as usize; - - let fragment = match ( - self.char_index_to_byte_offset(start_char), - // stop_char may equal sql_text.chars().count() (one past the end) - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => &self.sql_text[start_byte..stop_byte], - _ => "", - }; - - // Build the header line - let mut summary = String::from("== SQL"); - - if let Some(obj_type) = &self.object_type { - if !obj_type.is_empty() { - summary.push_str(" of "); - summary.push_str(obj_type); - - if let Some(obj_name) = &self.object_name { - if !obj_name.is_empty() { - summary.push(' '); - summary.push_str(obj_name); - } - } - } - } - - summary.push_str(&format!( - " (line {}, position {}) ==\n", - self.line, - self.start_position + 1 // Convert 0-based to 1-based for display - )); - - // Add the SQL text with fragment highlighted - summary.push_str(&self.sql_text); - summary.push('\n'); - - // Add caret pointer - let caret_position = self.start_position.max(0) as usize; - summary.push_str(&" ".repeat(caret_position)); - // fragment.chars().count() gives the correct display width for non-ASCII - summary.push_str(&"^".repeat(fragment.chars().count().max(1))); - - summary - } - - /// Returns the SQL fragment that caused the error. 
- #[cfg(test)] - fn fragment(&self) -> String { - let start_char = self.start_index.max(0) as usize; - let stop_char = (self.stop_index + 1).max(0) as usize; - - match ( - self.char_index_to_byte_offset(start_char), - self.char_index_to_byte_offset(stop_char).or_else(|| { - if stop_char == self.sql_text.chars().count() { - Some(self.sql_text.len()) - } else { - None - } - }), - ) { - (Some(start_byte), Some(stop_byte)) => self.sql_text[start_byte..stop_byte].to_string(), - _ => String::new(), - } - } -} - -use std::collections::HashMap; -use std::sync::RwLock; - -/// Map that stores QueryContext information for expressions during execution. -/// -/// This map is populated during plan deserialization and accessed -/// during error creation to attach SQL context to exceptions. -#[derive(Debug)] -pub struct QueryContextMap { - /// Map from expression ID to QueryContext - contexts: RwLock>>, -} - -impl QueryContextMap { - pub fn new() -> Self { - Self { - contexts: RwLock::new(HashMap::new()), - } - } - - /// Register a QueryContext for an expression ID. - /// - /// If the expression ID already exists, it will be replaced. - /// - /// # Arguments - /// * `expr_id` - Unique expression identifier from protobuf - /// * `context` - QueryContext containing SQL text and position info - pub fn register(&self, expr_id: u64, context: QueryContext) { - let mut contexts = self.contexts.write().unwrap(); - contexts.insert(expr_id, Arc::new(context)); - } - - /// Get the QueryContext for an expression ID. - /// - /// Returns None if no context is registered for this expression. - /// - /// # Arguments - /// * `expr_id` - Expression identifier to look up - pub fn get(&self, expr_id: u64) -> Option> { - let contexts = self.contexts.read().unwrap(); - contexts.get(&expr_id).cloned() - } - - /// Clear all registered contexts. - /// - /// This is typically called after plan execution completes to free memory. 
- pub fn clear(&self) { - let mut contexts = self.contexts.write().unwrap(); - contexts.clear(); - } - - /// Return the number of registered contexts (for debugging/testing) - pub fn len(&self) -> usize { - let contexts = self.contexts.read().unwrap(); - contexts.len() - } - - /// Check if the map is empty - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl Default for QueryContextMap { - fn default() -> Self { - Self::new() - } -} - -/// Create a new session-scoped QueryContextMap. -/// -/// This should be called once per SessionContext during plan creation -/// and passed to expressions that need query context for error reporting. -pub fn create_query_context_map() -> Arc { - Arc::new(QueryContextMap::new()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_query_context_creation() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - assert_eq!(*ctx.sql_text, "SELECT a/b FROM t"); - assert_eq!(ctx.start_index, 7); - assert_eq!(ctx.stop_index, 9); - assert_eq!(ctx.object_type, Some("Divide".to_string())); - assert_eq!(ctx.object_name, Some("a/b".to_string())); - assert_eq!(ctx.line, 1); - assert_eq!(ctx.start_position, 7); - } - - #[test] - fn test_query_context_serialization() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("Divide".to_string()), - Some("a/b".to_string()), - 1, - 7, - ); - - let json = serde_json::to_string(&ctx).unwrap(); - let deserialized: QueryContext = serde_json::from_str(&json).unwrap(); - - assert_eq!(ctx, deserialized); - } - - #[test] - fn test_format_summary() { - let ctx = QueryContext::new( - "SELECT a/b FROM t".to_string(), - 7, - 9, - Some("VIEW".to_string()), - Some("v1".to_string()), - 1, - 7, - ); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL of VIEW v1 (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM 
t")); - assert!(summary.contains("^^^")); // Three carets for "a/b" - } - - #[test] - fn test_format_summary_without_object() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let summary = ctx.format_summary(); - - assert!(summary.contains("== SQL (line 1, position 8) ==")); - assert!(summary.contains("SELECT a/b FROM t")); - } - - #[test] - fn test_fragment() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - assert_eq!(ctx.fragment(), "a/b"); - } - - #[test] - fn test_arc_string_sharing() { - let ctx1 = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let ctx2 = ctx1.clone(); - - // Arc should share the same allocation - assert!(Arc::ptr_eq(&ctx1.sql_text, &ctx2.sql_text)); - } - - #[test] - fn test_json_with_optional_fields() { - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - let json = serde_json::to_string(&ctx).unwrap(); - - // Should not serialize objectType and objectName when None - assert!(!json.contains("objectType")); - assert!(!json.contains("objectName")); - } - - #[test] - fn test_map_register_and_get() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx.clone()); - - let retrieved = map.get(1).unwrap(); - assert_eq!(*retrieved.sql_text, "SELECT a/b FROM t"); - assert_eq!(retrieved.start_index, 7); - } - - #[test] - fn test_map_get_nonexistent() { - let map = QueryContextMap::new(); - assert!(map.get(999).is_none()); - } - - #[test] - fn test_map_clear() { - let map = QueryContextMap::new(); - - let ctx = QueryContext::new("SELECT a/b FROM t".to_string(), 7, 9, None, None, 1, 7); - - map.register(1, ctx); - assert_eq!(map.len(), 1); - - map.clear(); - assert_eq!(map.len(), 0); - assert!(map.is_empty()); - } - - // Verify that fragment() and format_summary() correctly handle SQL text that - // 
contains multi-byte characters - - #[test] - fn test_fragment_non_ascii_accented() { - // "é" is a 2-byte UTF-8 sequence (U+00E9). - // SQL: "SELECT café FROM t" - // 0123456789... - // char indices: c=7, a=8, f=9, é=10, ' '=11 ... FROM = 12.. - // start_index=7, stop_index=10 should yield "café" - let sql = "SELECT café FROM t".to_string(); - let ctx = QueryContext::new(sql, 7, 10, None, None, 1, 7); - assert_eq!(ctx.fragment(), "café"); - } -} +// Re-export all query context types from the spark-errors crate +pub use datafusion_comet_spark_errors::{create_query_context_map, QueryContext, QueryContextMap}; From 900361f30c19b99c2162f67488b8d183ea12edc2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 18:40:49 -0600 Subject: [PATCH 08/11] chore(deps): remove unused serde and thiserror from spark-expr --- native/Cargo.lock | 2 -- native/spark-expr/Cargo.toml | 2 -- 2 files changed, 4 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index f4531b36e8..91cb789404 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1965,9 +1965,7 @@ dependencies = [ "num", "rand 0.10.0", "regex", - "serde", "serde_json", - "thiserror 2.0.18", "tokio", "twox-hash", ] diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index 990553a60f..fcb1a07bf2 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -33,9 +33,7 @@ datafusion = { workspace = true } chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } -serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -thiserror = { workspace = true } datafusion-comet-spark-errors = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" From 7435dacd4ab5bb9575bfad3c38ae4850eb1732cd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 19:24:13 -0600 Subject: [PATCH 09/11] refactor: flatten jvm_bridge submodule into jvm-bridge crate root Move macros, JVMClasses, and helper types from the 
jvm_bridge submodule to the crate root, eliminating the redundant jvm_bridge::jvm_bridge path. --- native/core/src/lib.rs | 6 +- .../src/{jvm_bridge => }/batch_iterator.rs | 0 .../src/{jvm_bridge => }/comet_exec.rs | 0 .../src/{jvm_bridge => }/comet_metric_node.rs | 0 .../comet_task_memory_manager.rs | 0 native/jvm-bridge/src/jvm_bridge/mod.rs | 391 ------------------ native/jvm-bridge/src/lib.rs | 372 ++++++++++++++++- 7 files changed, 375 insertions(+), 394 deletions(-) rename native/jvm-bridge/src/{jvm_bridge => }/batch_iterator.rs (100%) rename native/jvm-bridge/src/{jvm_bridge => }/comet_exec.rs (100%) rename native/jvm-bridge/src/{jvm_bridge => }/comet_metric_node.rs (100%) rename native/jvm-bridge/src/{jvm_bridge => }/comet_task_memory_manager.rs (100%) delete mode 100644 native/jvm-bridge/src/jvm_bridge/mod.rs diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 8ee3139ff9..8b9c426bab 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -56,9 +56,13 @@ use mimalloc::MiMalloc; // Re-export from jvm-bridge crate for internal use pub use datafusion_comet_jvm_bridge::errors; -pub use datafusion_comet_jvm_bridge::jvm_bridge; pub use datafusion_comet_jvm_bridge::JAVA_VM; +/// Re-export jvm-bridge items under the `jvm_bridge` name for convenience. 
+pub mod jvm_bridge { + pub use datafusion_comet_jvm_bridge::*; +} + use errors::{try_unwrap_or_throw, CometError, CometResult}; #[macro_use] diff --git a/native/jvm-bridge/src/jvm_bridge/batch_iterator.rs b/native/jvm-bridge/src/batch_iterator.rs similarity index 100% rename from native/jvm-bridge/src/jvm_bridge/batch_iterator.rs rename to native/jvm-bridge/src/batch_iterator.rs diff --git a/native/jvm-bridge/src/jvm_bridge/comet_exec.rs b/native/jvm-bridge/src/comet_exec.rs similarity index 100% rename from native/jvm-bridge/src/jvm_bridge/comet_exec.rs rename to native/jvm-bridge/src/comet_exec.rs diff --git a/native/jvm-bridge/src/jvm_bridge/comet_metric_node.rs b/native/jvm-bridge/src/comet_metric_node.rs similarity index 100% rename from native/jvm-bridge/src/jvm_bridge/comet_metric_node.rs rename to native/jvm-bridge/src/comet_metric_node.rs diff --git a/native/jvm-bridge/src/jvm_bridge/comet_task_memory_manager.rs b/native/jvm-bridge/src/comet_task_memory_manager.rs similarity index 100% rename from native/jvm-bridge/src/jvm_bridge/comet_task_memory_manager.rs rename to native/jvm-bridge/src/comet_task_memory_manager.rs diff --git a/native/jvm-bridge/src/jvm_bridge/mod.rs b/native/jvm-bridge/src/jvm_bridge/mod.rs deleted file mode 100644 index c561197627..0000000000 --- a/native/jvm-bridge/src/jvm_bridge/mod.rs +++ /dev/null @@ -1,391 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! JNI JVM related functions - -use crate::errors::CometResult; - -use jni::objects::JClass; -use jni::{ - errors::Error, - objects::{JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, - signature::ReturnType, - AttachGuard, JNIEnv, -}; -use once_cell::sync::OnceCell; - -/// Macro for converting JNI Error to Comet Error. -#[macro_export] -macro_rules! jni_map_error { - ($env:expr, $result:expr) => {{ - match $result { - Ok(result) => datafusion::error::Result::Ok(result), - Err(jni_error) => Err($crate::errors::CometError::JNI { source: jni_error }), - } - }}; -} - -/// Macro for converting Rust types to JNI types. -#[macro_export] -macro_rules! jvalues { - ($($args:expr,)* $(,)?) => {{ - &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] - }} -} - -/// Macro for calling a JNI method. -/// The syntax is: -/// jni_call!(env, comet_metric_node(metric_node).add(jname, value) -> ())?; -/// comet_metric_node is the class name stored in [[JVMClasses]]. -/// metric_node is the Java object on which the method is called. -/// add is the method name. -/// jname and value are the arguments. -#[macro_export] -macro_rules! jni_call { - ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ - let method_id = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] - }; - let ret_type = paste::paste! 
{ - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] - }.clone(); - let args = $crate::jvalues!($($args,)*); - - // Call the JVM method and obtain the returned value - let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); - - // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { - Err(exception.into()) - } else { - $crate::jni_map_error!($env, ret) - }; - - result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) - }} -} - -#[macro_export] -macro_rules! jni_static_call { - ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ - let clazz = &paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] - }; - let method_id = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] - }; - let ret_type = paste::paste! { - $crate::jvm_bridge::JVMClasses::get().[<$clsname>].[] - }.clone(); - let args = $crate::jvalues!($($args,)*); - - // Call the JVM static method and obtain the returned value - let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args); - - // Check if JVM has thrown any exception, and handle it if so. - let result = if let Some(exception) = $crate::jvm_bridge::check_exception($env)? { - Err(exception.into()) - } else { - $crate::jni_map_error!($env, ret) - }; - - result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) - }} -} - -/// Macro for creating a new global reference. -#[macro_export] -macro_rules! jni_new_global_ref { - ($env:expr, $obj:expr) => {{ - $crate::jni_map_error!($env, $env.new_global_ref($obj)) - }}; -} - -/// Wrapper for JString. Because we cannot implement `TryFrom` trait for `JString` as they -/// are defined in different crates. 
-pub struct StringWrapper<'a> { - value: JString<'a>, -} - -impl<'a> StringWrapper<'a> { - pub fn new(value: JString<'a>) -> StringWrapper<'a> { - Self { value } - } - - pub fn get(&self) -> &JString<'_> { - &self.value - } -} - -pub struct BinaryWrapper<'a> { - value: JObject<'a>, -} - -impl<'a> BinaryWrapper<'a> { - pub fn new(value: JObject<'a>) -> BinaryWrapper<'a> { - Self { value } - } - - pub fn get(&self) -> &JObject<'_> { - &self.value - } -} - -impl<'a> TryFrom> for StringWrapper<'a> { - type Error = Error; - - fn try_from(value: JValueOwned<'a>) -> Result, Error> { - match value { - JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))), - _ => Err(Error::WrongJValueType("object", value.type_name())), - } - } -} - -impl<'a> TryFrom> for BinaryWrapper<'a> { - type Error = Error; - - fn try_from(value: JValueOwned<'a>) -> Result, Error> { - match value { - JValueGen::Object(b) => Ok(BinaryWrapper::new(b)), - _ => Err(Error::WrongJValueType("object", value.type_name())), - } - } -} - -mod comet_exec; -pub use comet_exec::*; -mod batch_iterator; -mod comet_metric_node; -mod comet_task_memory_manager; - -use crate::{errors::CometError, JAVA_VM}; -use batch_iterator::CometBatchIterator; -pub use comet_metric_node::*; -pub use comet_task_memory_manager::*; - -/// The JVM classes that are used in the JNI calls. 
-#[allow(dead_code)] // we need to keep references to Java items to prevent GC -pub struct JVMClasses<'a> { - /// Cached JClass for "java.lang.Object" - java_lang_object: JClass<'a>, - /// Cached JClass for "java.lang.Class" - java_lang_class: JClass<'a>, - /// Cached JClass for "java.lang.Throwable" - java_lang_throwable: JClass<'a>, - /// Cached method ID for "java.lang.Object#getClass" - object_get_class_method: JMethodID, - /// Cached method ID for "java.lang.Class#getName" - class_get_name_method: JMethodID, - /// Cached method ID for "java.lang.Throwable#getMessage" - throwable_get_message_method: JMethodID, - /// Cached method ID for "java.lang.Throwable#getCause" - throwable_get_cause_method: JMethodID, - - /// The CometMetricNode class. Used for updating the metrics. - pub comet_metric_node: CometMetricNode<'a>, - /// The static CometExec class. Used for getting the subquery result. - pub comet_exec: CometExec<'a>, - /// The CometBatchIterator class. Used for iterating over the batches. - pub comet_batch_iterator: CometBatchIterator<'a>, - /// The CometTaskMemoryManager used for interacting with JVM side to - /// acquire & release native memory. - pub comet_task_memory_manager: CometTaskMemoryManager<'a>, -} - -unsafe impl Send for JVMClasses<'_> {} - -unsafe impl Sync for JVMClasses<'_> {} - -/// Keeps global references to JVM classes. Used for JNI calls to JVM. -static JVM_CLASSES: OnceCell = OnceCell::new(); - -impl JVMClasses<'_> { - /// Creates a new JVMClasses struct. - pub fn init(env: &mut JNIEnv) { - JVM_CLASSES.get_or_init(|| { - // A hack to make the `JNIEnv` static. It is not safe but we don't really use the - // `JNIEnv` except for creating the global references of the classes. 
- let env = unsafe { std::mem::transmute::<&mut JNIEnv, &'static mut JNIEnv>(env) }; - - let java_lang_object = env.find_class("java/lang/Object").unwrap(); - let object_get_class_method = env - .get_method_id(&java_lang_object, "getClass", "()Ljava/lang/Class;") - .unwrap(); - - let java_lang_class = env.find_class("java/lang/Class").unwrap(); - let class_get_name_method = env - .get_method_id(&java_lang_class, "getName", "()Ljava/lang/String;") - .unwrap(); - - let java_lang_throwable = env.find_class("java/lang/Throwable").unwrap(); - let throwable_get_message_method = env - .get_method_id(&java_lang_throwable, "getMessage", "()Ljava/lang/String;") - .unwrap(); - - let throwable_get_cause_method = env - .get_method_id(&java_lang_throwable, "getCause", "()Ljava/lang/Throwable;") - .unwrap(); - - // SAFETY: According to the documentation for `JMethodID`, it is our - // responsibility to maintain a reference to the `JClass` instances where the - // methods were accessed from to prevent the methods from being garbage-collected - JVMClasses { - java_lang_object, - java_lang_class, - java_lang_throwable, - object_get_class_method, - class_get_name_method, - throwable_get_message_method, - throwable_get_cause_method, - comet_metric_node: CometMetricNode::new(env).unwrap(), - comet_exec: CometExec::new(env).unwrap(), - comet_batch_iterator: CometBatchIterator::new(env).unwrap(), - comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), - } - }); - } - - pub fn get() -> &'static JVMClasses<'static> { - debug_assert!( - JVM_CLASSES.get().is_some(), - "JVMClasses::get: not initialized" - ); - unsafe { JVM_CLASSES.get_unchecked() } - } - - /// Gets the JNIEnv for the current thread. 
- pub fn get_env() -> CometResult> { - debug_assert!( - JAVA_VM.get().is_some(), - "JVMClasses::get_env: JAVA_VM not initialized" - ); - unsafe { - let java_vm = JAVA_VM.get_unchecked(); - java_vm.attach_current_thread().map_err(|e| { - CometError::Internal(format!( - "JVMClasses::get_env() failed to attach current thread: {e}" - )) - }) - } - } -} - -pub fn check_exception(env: &mut JNIEnv) -> CometResult> { - let result = if env.exception_check()? { - let exception = env.exception_occurred()?; - env.exception_clear()?; - let exception_err = convert_exception(env, &exception)?; - Some(exception_err) - } else { - None - }; - - Ok(result) -} - -/// get the class name of the exception by: -/// 1. get the `Class` object of the input `throwable` via `Object#getClass` method -/// 2. get the exception class name via calling `Class#getName` on the above object -fn get_throwable_class_name( - env: &mut JNIEnv, - jvm_classes: &JVMClasses, - throwable: &JThrowable, -) -> CometResult { - unsafe { - let class_obj = env - .call_method_unchecked( - throwable, - jvm_classes.object_get_class_method, - ReturnType::Object, - &[], - )? - .l()?; - let class_name = env - .call_method_unchecked( - class_obj, - jvm_classes.class_get_name_method, - ReturnType::Object, - &[], - )? - .l()? - .into(); - let class_name_str = env.get_string(&class_name)?.into(); - - Ok(class_name_str) - } -} - -/// Get the exception message via calling `Throwable#getMessage` on the throwable object -fn get_throwable_message( - env: &mut JNIEnv, - jvm_classes: &JVMClasses, - throwable: &JThrowable, -) -> CometResult { - unsafe { - let message: JString = env - .call_method_unchecked( - throwable, - jvm_classes.throwable_get_message_method, - ReturnType::Object, - &[], - )? - .l()? 
- .into(); - let message_str = if !message.is_null() { - env.get_string(&message)?.into() - } else { - String::from("null") - }; - - let cause: JThrowable = env - .call_method_unchecked( - throwable, - jvm_classes.throwable_get_cause_method, - ReturnType::Object, - &[], - )? - .l()? - .into(); - - if !cause.is_null() { - let cause_class_name = get_throwable_class_name(env, jvm_classes, &cause)?; - let cause_message = get_throwable_message(env, jvm_classes, &cause)?; - Ok(format!( - "{message_str}\nCaused by: {cause_class_name}: {cause_message}" - )) - } else { - Ok(message_str) - } - } -} - -/// Given a `JThrowable` which is thrown from calling a Java method on the native side, -/// this converts it into a `CometError::JavaException` with the exception class name -/// and exception message. This error can then be populated to the JVM side to let -/// users know the cause of the native side error. -pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) -> CometResult { - let cache = JVMClasses::get(); - let exception_class_name_str = get_throwable_class_name(env, cache, throwable)?; - let message_str = get_throwable_message(env, cache, throwable)?; - - Ok(CometError::JavaException { - class: exception_class_name_str, - msg: message_str, - throwable: env.new_global_ref(throwable)?, - }) -} diff --git a/native/jvm-bridge/src/lib.rs b/native/jvm-bridge/src/lib.rs index 19bfb39cfa..0ae38f334e 100644 --- a/native/jvm-bridge/src/lib.rs +++ b/native/jvm-bridge/src/lib.rs @@ -21,11 +21,379 @@ #![allow(clippy::result_large_err)] -use jni::JavaVM; +use jni::objects::JClass; +use jni::{ + errors::Error, + objects::{JMethodID, JObject, JString, JThrowable, JValueGen, JValueOwned}, + signature::ReturnType, + AttachGuard, JNIEnv, JavaVM, +}; use once_cell::sync::OnceCell; +use errors::{CometError, CometResult}; + pub mod errors; -pub mod jvm_bridge; /// Global reference to the Java VM, initialized during native library setup. 
pub static JAVA_VM: OnceCell = OnceCell::new(); + +/// Macro for converting JNI Error to Comet Error. +#[macro_export] +macro_rules! jni_map_error { + ($env:expr, $result:expr) => {{ + match $result { + Ok(result) => datafusion::error::Result::Ok(result), + Err(jni_error) => Err($crate::errors::CometError::JNI { source: jni_error }), + } + }}; +} + +/// Macro for converting Rust types to JNI types. +#[macro_export] +macro_rules! jvalues { + ($($args:expr,)* $(,)?) => {{ + &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] + }} +} + +/// Macro for calling a JNI method. +/// The syntax is: +/// jni_call!(env, comet_metric_node(metric_node).add(jname, value) -> ())?; +/// comet_metric_node is the class name stored in [[JVMClasses]]. +/// metric_node is the Java object on which the method is called. +/// add is the method name. +/// jname and value are the arguments. +#[macro_export] +macro_rules! jni_call { + ($env:expr, $clsname:ident($obj:expr).$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ + let method_id = paste::paste! { + $crate::JVMClasses::get().[<$clsname>].[] + }; + let ret_type = paste::paste! { + $crate::JVMClasses::get().[<$clsname>].[] + }.clone(); + let args = $crate::jvalues!($($args,)*); + + // Call the JVM method and obtain the returned value + let ret = $env.call_method_unchecked($obj, method_id, ret_type, args); + + // Check if JVM has thrown any exception, and handle it if so. + let result = if let Some(exception) = $crate::check_exception($env)? { + Err(exception.into()) + } else { + $crate::jni_map_error!($env, ret) + }; + + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) + }} +} + +#[macro_export] +macro_rules! jni_static_call { + ($env:expr, $clsname:ident.$method:ident($($args:expr),* $(,)?) -> $ret:ty) => {{ + let clazz = &paste::paste! { + $crate::JVMClasses::get().[<$clsname>].[] + }; + let method_id = paste::paste! 
{ + $crate::JVMClasses::get().[<$clsname>].[] + }; + let ret_type = paste::paste! { + $crate::JVMClasses::get().[<$clsname>].[] + }.clone(); + let args = $crate::jvalues!($($args,)*); + + // Call the JVM static method and obtain the returned value + let ret = $env.call_static_method_unchecked(clazz, method_id, ret_type, args); + + // Check if JVM has thrown any exception, and handle it if so. + let result = if let Some(exception) = $crate::check_exception($env)? { + Err(exception.into()) + } else { + $crate::jni_map_error!($env, ret) + }; + + result.and_then(|result| $crate::jni_map_error!($env, <$ret>::try_from(result))) + }} +} + +/// Macro for creating a new global reference. +#[macro_export] +macro_rules! jni_new_global_ref { + ($env:expr, $obj:expr) => {{ + $crate::jni_map_error!($env, $env.new_global_ref($obj)) + }}; +} + +/// Wrapper for JString. Because we cannot implement `TryFrom` trait for `JString` as they +/// are defined in different crates. +pub struct StringWrapper<'a> { + value: JString<'a>, +} + +impl<'a> StringWrapper<'a> { + pub fn new(value: JString<'a>) -> StringWrapper<'a> { + Self { value } + } + + pub fn get(&self) -> &JString<'_> { + &self.value + } +} + +pub struct BinaryWrapper<'a> { + value: JObject<'a>, +} + +impl<'a> BinaryWrapper<'a> { + pub fn new(value: JObject<'a>) -> BinaryWrapper<'a> { + Self { value } + } + + pub fn get(&self) -> &JObject<'_> { + &self.value + } +} + +impl<'a> TryFrom> for StringWrapper<'a> { + type Error = Error; + + fn try_from(value: JValueOwned<'a>) -> Result, Error> { + match value { + JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))), + _ => Err(Error::WrongJValueType("object", value.type_name())), + } + } +} + +impl<'a> TryFrom> for BinaryWrapper<'a> { + type Error = Error; + + fn try_from(value: JValueOwned<'a>) -> Result, Error> { + match value { + JValueGen::Object(b) => Ok(BinaryWrapper::new(b)), + _ => Err(Error::WrongJValueType("object", value.type_name())), + } + } +} + +mod 
comet_exec; +pub use comet_exec::*; +mod batch_iterator; +mod comet_metric_node; +mod comet_task_memory_manager; + +use batch_iterator::CometBatchIterator; +pub use comet_metric_node::*; +pub use comet_task_memory_manager::*; + +/// The JVM classes that are used in the JNI calls. +#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct JVMClasses<'a> { + /// Cached JClass for "java.lang.Object" + java_lang_object: JClass<'a>, + /// Cached JClass for "java.lang.Class" + java_lang_class: JClass<'a>, + /// Cached JClass for "java.lang.Throwable" + java_lang_throwable: JClass<'a>, + /// Cached method ID for "java.lang.Object#getClass" + object_get_class_method: JMethodID, + /// Cached method ID for "java.lang.Class#getName" + class_get_name_method: JMethodID, + /// Cached method ID for "java.lang.Throwable#getMessage" + throwable_get_message_method: JMethodID, + /// Cached method ID for "java.lang.Throwable#getCause" + throwable_get_cause_method: JMethodID, + + /// The CometMetricNode class. Used for updating the metrics. + pub comet_metric_node: CometMetricNode<'a>, + /// The static CometExec class. Used for getting the subquery result. + pub comet_exec: CometExec<'a>, + /// The CometBatchIterator class. Used for iterating over the batches. + pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometTaskMemoryManager used for interacting with JVM side to + /// acquire & release native memory. + pub comet_task_memory_manager: CometTaskMemoryManager<'a>, +} + +unsafe impl Send for JVMClasses<'_> {} + +unsafe impl Sync for JVMClasses<'_> {} + +/// Keeps global references to JVM classes. Used for JNI calls to JVM. +static JVM_CLASSES: OnceCell = OnceCell::new(); + +impl JVMClasses<'_> { + /// Creates a new JVMClasses struct. + pub fn init(env: &mut JNIEnv) { + JVM_CLASSES.get_or_init(|| { + // A hack to make the `JNIEnv` static. 
It is not safe but we don't really use the + // `JNIEnv` except for creating the global references of the classes. + let env = unsafe { std::mem::transmute::<&mut JNIEnv, &'static mut JNIEnv>(env) }; + + let java_lang_object = env.find_class("java/lang/Object").unwrap(); + let object_get_class_method = env + .get_method_id(&java_lang_object, "getClass", "()Ljava/lang/Class;") + .unwrap(); + + let java_lang_class = env.find_class("java/lang/Class").unwrap(); + let class_get_name_method = env + .get_method_id(&java_lang_class, "getName", "()Ljava/lang/String;") + .unwrap(); + + let java_lang_throwable = env.find_class("java/lang/Throwable").unwrap(); + let throwable_get_message_method = env + .get_method_id(&java_lang_throwable, "getMessage", "()Ljava/lang/String;") + .unwrap(); + + let throwable_get_cause_method = env + .get_method_id(&java_lang_throwable, "getCause", "()Ljava/lang/Throwable;") + .unwrap(); + + // SAFETY: According to the documentation for `JMethodID`, it is our + // responsibility to maintain a reference to the `JClass` instances where the + // methods were accessed from to prevent the methods from being garbage-collected + JVMClasses { + java_lang_object, + java_lang_class, + java_lang_throwable, + object_get_class_method, + class_get_name_method, + throwable_get_message_method, + throwable_get_cause_method, + comet_metric_node: CometMetricNode::new(env).unwrap(), + comet_exec: CometExec::new(env).unwrap(), + comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), + } + }); + } + + pub fn get() -> &'static JVMClasses<'static> { + debug_assert!( + JVM_CLASSES.get().is_some(), + "JVMClasses::get: not initialized" + ); + unsafe { JVM_CLASSES.get_unchecked() } + } + + /// Gets the JNIEnv for the current thread. 
+ pub fn get_env() -> CometResult> { + debug_assert!( + JAVA_VM.get().is_some(), + "JVMClasses::get_env: JAVA_VM not initialized" + ); + unsafe { + let java_vm = JAVA_VM.get_unchecked(); + java_vm.attach_current_thread().map_err(|e| { + CometError::Internal(format!( + "JVMClasses::get_env() failed to attach current thread: {e}" + )) + }) + } + } +} + +pub fn check_exception(env: &mut JNIEnv) -> CometResult> { + let result = if env.exception_check()? { + let exception = env.exception_occurred()?; + env.exception_clear()?; + let exception_err = convert_exception(env, &exception)?; + Some(exception_err) + } else { + None + }; + + Ok(result) +} + +/// get the class name of the exception by: +/// 1. get the `Class` object of the input `throwable` via `Object#getClass` method +/// 2. get the exception class name via calling `Class#getName` on the above object +fn get_throwable_class_name( + env: &mut JNIEnv, + jvm_classes: &JVMClasses, + throwable: &JThrowable, +) -> CometResult { + unsafe { + let class_obj = env + .call_method_unchecked( + throwable, + jvm_classes.object_get_class_method, + ReturnType::Object, + &[], + )? + .l()?; + let class_name = env + .call_method_unchecked( + class_obj, + jvm_classes.class_get_name_method, + ReturnType::Object, + &[], + )? + .l()? + .into(); + let class_name_str = env.get_string(&class_name)?.into(); + + Ok(class_name_str) + } +} + +/// Get the exception message via calling `Throwable#getMessage` on the throwable object +fn get_throwable_message( + env: &mut JNIEnv, + jvm_classes: &JVMClasses, + throwable: &JThrowable, +) -> CometResult { + unsafe { + let message: JString = env + .call_method_unchecked( + throwable, + jvm_classes.throwable_get_message_method, + ReturnType::Object, + &[], + )? + .l()? 
+ .into(); + let message_str = if !message.is_null() { + env.get_string(&message)?.into() + } else { + String::from("null") + }; + + let cause: JThrowable = env + .call_method_unchecked( + throwable, + jvm_classes.throwable_get_cause_method, + ReturnType::Object, + &[], + )? + .l()? + .into(); + + if !cause.is_null() { + let cause_class_name = get_throwable_class_name(env, jvm_classes, &cause)?; + let cause_message = get_throwable_message(env, jvm_classes, &cause)?; + Ok(format!( + "{message_str}\nCaused by: {cause_class_name}: {cause_message}" + )) + } else { + Ok(message_str) + } + } +} + +/// Given a `JThrowable` which is thrown from calling a Java method on the native side, +/// this converts it into a `CometError::JavaException` with the exception class name +/// and exception message. This error can then be populated to the JVM side to let +/// users know the cause of the native side error. +pub fn convert_exception(env: &mut JNIEnv, throwable: &JThrowable) -> CometResult { + let cache = JVMClasses::get(); + let exception_class_name_str = get_throwable_class_name(env, cache, throwable)?; + let message_str = get_throwable_message(env, cache, throwable)?; + + Ok(CometError::JavaException { + class: exception_class_name_str, + msg: message_str, + throwable: env.new_global_ref(throwable)?, + }) +} From 4fbcfa6a14749d1728f4f08fc0a8bc22be0a6e34 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 19:56:49 -0600 Subject: [PATCH 10/11] refactor: rename jvm-bridge crate to jni-bridge --- native/Cargo.lock | 4 ++-- native/Cargo.toml | 6 +++--- native/core/Cargo.toml | 2 +- native/core/src/lib.rs | 8 ++++---- native/{jvm-bridge => jni-bridge}/Cargo.toml | 4 ++-- native/{jvm-bridge => jni-bridge}/src/batch_iterator.rs | 0 native/{jvm-bridge => jni-bridge}/src/comet_exec.rs | 0 .../{jvm-bridge => jni-bridge}/src/comet_metric_node.rs | 0 .../src/comet_task_memory_manager.rs | 0 native/{jvm-bridge => jni-bridge}/src/errors.rs | 0 native/{jvm-bridge => 
jni-bridge}/src/lib.rs | 4 ++-- native/{jvm-bridge => jni-bridge}/testdata/backtrace.txt | 0 native/{jvm-bridge => jni-bridge}/testdata/stacktrace.txt | 0 13 files changed, 14 insertions(+), 14 deletions(-) rename native/{jvm-bridge => jni-bridge}/Cargo.toml (94%) rename native/{jvm-bridge => jni-bridge}/src/batch_iterator.rs (100%) rename native/{jvm-bridge => jni-bridge}/src/comet_exec.rs (100%) rename native/{jvm-bridge => jni-bridge}/src/comet_metric_node.rs (100%) rename native/{jvm-bridge => jni-bridge}/src/comet_task_memory_manager.rs (100%) rename native/{jvm-bridge => jni-bridge}/src/errors.rs (100%) rename native/{jvm-bridge => jni-bridge}/src/lib.rs (99%) rename native/{jvm-bridge => jni-bridge}/testdata/backtrace.txt (100%) rename native/{jvm-bridge => jni-bridge}/testdata/stacktrace.txt (100%) diff --git a/native/Cargo.lock b/native/Cargo.lock index 91cb789404..fdf915ac78 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1837,7 +1837,7 @@ dependencies = [ "crc32fast", "criterion", "datafusion", - "datafusion-comet-jvm-bridge", + "datafusion-comet-jni-bridge", "datafusion-comet-objectstore-hdfs", "datafusion-comet-proto", "datafusion-comet-spark-expr", @@ -1899,7 +1899,7 @@ dependencies = [ ] [[package]] -name = "datafusion-comet-jvm-bridge" +name = "datafusion-comet-jni-bridge" version = "0.14.0" dependencies = [ "arrow", diff --git a/native/Cargo.toml b/native/Cargo.toml index 960fb56509..fc64d522bf 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[workspace] -default-members = ["core", "spark-expr", "spark-errors", "proto", "jvm-bridge"] -members = ["core", "spark-expr", "spark-errors", "proto", "jvm-bridge", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "spark-errors", "proto", "jni-bridge"] +members = ["core", "spark-expr", "spark-errors", "proto", "jni-bridge", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -44,7 +44,7 @@ datafusion-physical-expr-adapter = { version = "52.2.0" } datafusion-spark = { version = "52.2.0" } datafusion-comet-spark-expr = { path = "spark-expr" } datafusion-comet-spark-errors = { path = "spark-errors" } -datafusion-comet-jvm-bridge = { path = "jvm-bridge" } +datafusion-comet-jni-bridge = { path = "jni-bridge" } datafusion-comet-proto = { path = "proto" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.10" } diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 2780eb01d1..db8a75849e 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -65,7 +65,7 @@ once_cell = "1.18.0" crc32fast = "1.3.2" simd-adler32 = "0.3.7" datafusion-comet-spark-expr = { workspace = true } -datafusion-comet-jvm-bridge = { workspace = true } +datafusion-comet-jni-bridge = { workspace = true } datafusion-comet-proto = { workspace = true } object_store = { workspace = true } url = { workspace = true } diff --git a/native/core/src/lib.rs b/native/core/src/lib.rs index 8b9c426bab..1b87dc1dba 100644 --- a/native/core/src/lib.rs +++ b/native/core/src/lib.rs @@ -27,7 +27,7 @@ extern crate core; #[macro_use] -extern crate datafusion_comet_jvm_bridge; +extern crate datafusion_comet_jni_bridge; use jni::{ objects::{JClass, JString}, @@ -55,12 +55,12 @@ use tikv_jemallocator::Jemalloc; use mimalloc::MiMalloc; // Re-export from jvm-bridge crate for internal use -pub use datafusion_comet_jvm_bridge::errors; -pub use datafusion_comet_jvm_bridge::JAVA_VM; +pub use datafusion_comet_jni_bridge::errors; +pub use 
datafusion_comet_jni_bridge::JAVA_VM; /// Re-export jvm-bridge items under the `jvm_bridge` name for convenience. pub mod jvm_bridge { - pub use datafusion_comet_jvm_bridge::*; + pub use datafusion_comet_jni_bridge::*; } use errors::{try_unwrap_or_throw, CometError, CometResult}; diff --git a/native/jvm-bridge/Cargo.toml b/native/jni-bridge/Cargo.toml similarity index 94% rename from native/jvm-bridge/Cargo.toml rename to native/jni-bridge/Cargo.toml index a869dd1eee..fb47caecf8 100644 --- a/native/jvm-bridge/Cargo.toml +++ b/native/jni-bridge/Cargo.toml @@ -16,12 +16,12 @@ # under the License. [package] -name = "datafusion-comet-jvm-bridge" +name = "datafusion-comet-jni-bridge" version = { workspace = true } homepage = "https://datafusion.apache.org/comet" repository = "https://github.com/apache/datafusion-comet" authors = ["Apache DataFusion "] -description = "Apache DataFusion Comet: JVM bridge" +description = "Apache DataFusion Comet: JNI bridge" readme = "README.md" license = "Apache-2.0" edition = "2021" diff --git a/native/jvm-bridge/src/batch_iterator.rs b/native/jni-bridge/src/batch_iterator.rs similarity index 100% rename from native/jvm-bridge/src/batch_iterator.rs rename to native/jni-bridge/src/batch_iterator.rs diff --git a/native/jvm-bridge/src/comet_exec.rs b/native/jni-bridge/src/comet_exec.rs similarity index 100% rename from native/jvm-bridge/src/comet_exec.rs rename to native/jni-bridge/src/comet_exec.rs diff --git a/native/jvm-bridge/src/comet_metric_node.rs b/native/jni-bridge/src/comet_metric_node.rs similarity index 100% rename from native/jvm-bridge/src/comet_metric_node.rs rename to native/jni-bridge/src/comet_metric_node.rs diff --git a/native/jvm-bridge/src/comet_task_memory_manager.rs b/native/jni-bridge/src/comet_task_memory_manager.rs similarity index 100% rename from native/jvm-bridge/src/comet_task_memory_manager.rs rename to native/jni-bridge/src/comet_task_memory_manager.rs diff --git a/native/jvm-bridge/src/errors.rs 
b/native/jni-bridge/src/errors.rs similarity index 100% rename from native/jvm-bridge/src/errors.rs rename to native/jni-bridge/src/errors.rs diff --git a/native/jvm-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs similarity index 99% rename from native/jvm-bridge/src/lib.rs rename to native/jni-bridge/src/lib.rs index 0ae38f334e..456fbdf688 100644 --- a/native/jvm-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -//! JVM bridge for Apache DataFusion Comet. +//! JNI bridge for Apache DataFusion Comet. //! -//! This crate provides the JNI/JVM interaction layer used across Comet's native Rust crates. +//! This crate provides the JNI interaction layer used across Comet's native Rust crates. #![allow(clippy::result_large_err)] diff --git a/native/jvm-bridge/testdata/backtrace.txt b/native/jni-bridge/testdata/backtrace.txt similarity index 100% rename from native/jvm-bridge/testdata/backtrace.txt rename to native/jni-bridge/testdata/backtrace.txt diff --git a/native/jvm-bridge/testdata/stacktrace.txt b/native/jni-bridge/testdata/stacktrace.txt similarity index 100% rename from native/jvm-bridge/testdata/stacktrace.txt rename to native/jni-bridge/testdata/stacktrace.txt From 373e7dbd98d354af89ecfe6f0f95b14d9e1176e0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 11 Mar 2026 20:54:54 -0600 Subject: [PATCH 11/11] refactor: rename spark-errors crate to common --- native/Cargo.lock | 26 +++++++++---------- native/Cargo.toml | 6 ++--- native/{spark-errors => common}/Cargo.toml | 6 ++--- native/{spark-errors => common}/src/error.rs | 0 native/{spark-errors => common}/src/lib.rs | 0 .../src/query_context.rs | 0 native/jni-bridge/Cargo.toml | 2 +- native/jni-bridge/src/errors.rs | 2 +- native/spark-expr/Cargo.toml | 2 +- native/spark-expr/src/error.rs | 4 +-- native/spark-expr/src/query_context.rs | 4 +-- 11 files changed, 26 insertions(+), 26 deletions(-) rename 
native/{spark-errors => common}/Cargo.toml (89%) rename native/{spark-errors => common}/src/error.rs (100%) rename native/{spark-errors => common}/src/lib.rs (100%) rename native/{spark-errors => common}/src/query_context.rs (100%) diff --git a/native/Cargo.lock b/native/Cargo.lock index fdf915ac78..af1a096845 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1883,6 +1883,17 @@ dependencies = [ "zstd", ] +[[package]] +name = "datafusion-comet-common" +version = "0.14.0" +dependencies = [ + "arrow", + "datafusion", + "serde", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "datafusion-comet-fs-hdfs3" version = "0.1.12" @@ -1905,7 +1916,7 @@ dependencies = [ "arrow", "assertables", "datafusion", - "datafusion-comet-spark-errors", + "datafusion-comet-common", "jni", "lazy_static", "once_cell", @@ -1938,17 +1949,6 @@ dependencies = [ "prost-build", ] -[[package]] -name = "datafusion-comet-spark-errors" -version = "0.14.0" -dependencies = [ - "arrow", - "datafusion", - "serde", - "serde_json", - "thiserror 2.0.18", -] - [[package]] name = "datafusion-comet-spark-expr" version = "0.14.0" @@ -1959,7 +1959,7 @@ dependencies = [ "chrono-tz", "criterion", "datafusion", - "datafusion-comet-spark-errors", + "datafusion-comet-common", "futures", "hex", "num", diff --git a/native/Cargo.toml b/native/Cargo.toml index fc64d522bf..fe9d524ecc 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[workspace] -default-members = ["core", "spark-expr", "spark-errors", "proto", "jni-bridge"] -members = ["core", "spark-expr", "spark-errors", "proto", "jni-bridge", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "common", "proto", "jni-bridge"] +members = ["core", "spark-expr", "common", "proto", "jni-bridge", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -43,7 +43,7 @@ datafusion-datasource = { version = "52.2.0" } datafusion-physical-expr-adapter = { version = "52.2.0" } datafusion-spark = { version = "52.2.0" } datafusion-comet-spark-expr = { path = "spark-expr" } -datafusion-comet-spark-errors = { path = "spark-errors" } +datafusion-comet-common = { path = "common" } datafusion-comet-jni-bridge = { path = "jni-bridge" } datafusion-comet-proto = { path = "proto" } chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/native/spark-errors/Cargo.toml b/native/common/Cargo.toml similarity index 89% rename from native/spark-errors/Cargo.toml rename to native/common/Cargo.toml index 1b0e40053f..3bbc44856e 100644 --- a/native/spark-errors/Cargo.toml +++ b/native/common/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[package] -name = "datafusion-comet-spark-errors" -description = "Apache DataFusion Comet: Spark error types" +name = "datafusion-comet-common" +description = "Apache DataFusion Comet: common types shared across crates" version = { workspace = true } homepage = { workspace = true } repository = { workspace = true } @@ -36,5 +36,5 @@ serde_json = "1.0" thiserror = { workspace = true } [lib] -name = "datafusion_comet_spark_errors" +name = "datafusion_comet_common" path = "src/lib.rs" diff --git a/native/spark-errors/src/error.rs b/native/common/src/error.rs similarity index 100% rename from native/spark-errors/src/error.rs rename to native/common/src/error.rs diff --git a/native/spark-errors/src/lib.rs b/native/common/src/lib.rs similarity index 100% rename from native/spark-errors/src/lib.rs rename to native/common/src/lib.rs diff --git a/native/spark-errors/src/query_context.rs b/native/common/src/query_context.rs similarity index 100% rename from native/spark-errors/src/query_context.rs rename to native/common/src/query_context.rs diff --git a/native/jni-bridge/Cargo.toml b/native/jni-bridge/Cargo.toml index fb47caecf8..0c50825667 100644 --- a/native/jni-bridge/Cargo.toml +++ b/native/jni-bridge/Cargo.toml @@ -39,7 +39,7 @@ lazy_static = "1.4.0" once_cell = "1.18.0" paste = "1.0.14" prost = "0.14.3" -datafusion-comet-spark-errors = { workspace = true } +datafusion-comet-common = { workspace = true } [dev-dependencies] jni = { version = "0.21", features = ["invocation"] } diff --git a/native/jni-bridge/src/errors.rs b/native/jni-bridge/src/errors.rs index 62ce3b05fb..aff471e245 100644 --- a/native/jni-bridge/src/errors.rs +++ b/native/jni-bridge/src/errors.rs @@ -19,7 +19,7 @@ use arrow::error::ArrowError; use datafusion::common::DataFusionError; -use datafusion_comet_spark_errors::{SparkError, SparkErrorWithContext}; +use datafusion_comet_common::{SparkError, SparkErrorWithContext}; use jni::errors::{Exception, ToException}; use regex::Regex; diff --git 
a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index fcb1a07bf2..a249b7de3c 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -34,7 +34,7 @@ chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } serde_json = "1.0" -datafusion-comet-spark-errors = { workspace = true } +datafusion-comet-common = { workspace = true } futures = { workspace = true } twox-hash = "2.1.2" rand = { workspace = true } diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs index 9d0e765b92..bb87915c7b 100644 --- a/native/spark-expr/src/error.rs +++ b/native/spark-expr/src/error.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Re-export all error types from the spark-errors crate -pub use datafusion_comet_spark_errors::{ +// Re-export all error types from the common crate +pub use datafusion_comet_common::{ decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult, }; diff --git a/native/spark-expr/src/query_context.rs b/native/spark-expr/src/query_context.rs index 431ba54867..7d96a3f1e5 100644 --- a/native/spark-expr/src/query_context.rs +++ b/native/spark-expr/src/query_context.rs @@ -15,5 +15,5 @@ // specific language governing permissions and limitations // under the License. -// Re-export all query context types from the spark-errors crate -pub use datafusion_comet_spark_errors::{create_query_context_map, QueryContext, QueryContextMap}; +// Re-export all query context types from the common crate +pub use datafusion_comet_common::{create_query_context_map, QueryContext, QueryContextMap};