diff --git a/native/core/src/execution/expressions/array.rs b/native/core/src/execution/expressions/array.rs new file mode 100644 index 0000000000..7e8921c8d8 --- /dev/null +++ b/native/core/src/execution/expressions/array.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array expression builders + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use datafusion_comet_spark_expr::{ArrayExistsExpr, LambdaVariableExpr}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::expression_registry::ExpressionBuilder; +use crate::execution::planner::PhysicalPlanner; +use crate::execution::serde::to_arrow_datatype; +use crate::extract_expr; + +pub struct ArrayExistsBuilder; + +impl ExpressionBuilder for ArrayExistsBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, ArrayExists); + let array_expr = + planner.create_expr(expr.array.as_ref().unwrap(), Arc::clone(&input_schema))?; + let lambda_body = planner.create_expr(expr.lambda_body.as_ref().unwrap(), input_schema)?; + Ok(Arc::new(ArrayExistsExpr::new( + array_expr, + lambda_body, + expr.follow_three_valued_logic, + ))) + } +} + +pub struct LambdaVariableBuilder; + +impl ExpressionBuilder for LambdaVariableBuilder { + fn build( + &self, + spark_expr: &Expr, + _input_schema: SchemaRef, + _planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, LambdaVariable); + let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(LambdaVariableExpr::new(data_type))) + } +} diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index 563d62e91b..02f50f1b7e 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -18,6 +18,7 @@ //! Native DataFusion expressions pub mod arithmetic; +pub mod array; pub mod bitwise; pub mod comparison; pub mod logical; diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs index 34aa3de179..8097eb94d1 100644 --- a/native/core/src/execution/planner/expression_registry.rs +++ b/native/core/src/execution/planner/expression_registry.rs @@ -103,6 +103,8 @@ pub enum ExpressionType { Randn, SparkPartitionId, MonotonicallyIncreasingId, + ArrayExists, + LambdaVariable, // Time functions Hour, @@ -184,6 +186,9 @@ impl ExpressionRegistry { // Register temporal expressions self.register_temporal_expressions(); + + // Register array expressions + self.register_array_expressions(); } /// Register arithmetic expression builders @@ -306,6 +311,18 @@ impl ExpressionRegistry { ); } + /// Register array expression builders + fn register_array_expressions(&mut self) { + use crate::execution::expressions::array::*; + + self.builders + .insert(ExpressionType::ArrayExists, Box::new(ArrayExistsBuilder)); + self.builders.insert( + ExpressionType::LambdaVariable, + Box::new(LambdaVariableBuilder), + ); + } + /// Extract expression type from Spark protobuf expression fn get_expression_type(spark_expr: &Expr) -> Result { match spark_expr.expr_struct.as_ref() { @@ -370,6 +387,8 @@ impl ExpressionRegistry { Some(ExprStruct::MonotonicallyIncreasingId(_)) => { Ok(ExpressionType::MonotonicallyIncreasingId) } + Some(ExprStruct::ArrayExists(_)) => Ok(ExpressionType::ArrayExists), + Some(ExprStruct::LambdaVariable(_)) => Ok(ExpressionType::LambdaVariable), Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 944505ba1c..d3420fefd4 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -88,6 +88,8 @@ message Expr { UnixTimestamp unix_timestamp = 65; FromJson from_json = 66; ToCsv to_csv = 67; + ArrayExists array_exists = 68; + LambdaVariable lambda_variable = 69; } } @@ -440,3 +442,17 @@ message ArrayJoin { message Rand { int64 seed = 1; } + +message ArrayExists { + Expr array = 1; + Expr lambda_body = 2; + bool follow_three_valued_logic = 3; +} + +// Currently only supports a single lambda variable per expression. The variable +// is resolved by column index (always the last column in the expanded batch +// constructed by ArrayExistsExpr). Extending to multi-argument lambdas +// (e.g. transform(array, (x, i) -> ...)) would require adding an identifier. +message LambdaVariable { + DataType datatype = 1; +} diff --git a/native/spark-expr/src/array_funcs/array_exists.rs b/native/spark-expr/src/array_funcs/array_exists.rs new file mode 100644 index 0000000000..c0be034be2 --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_exists.rs @@ -0,0 +1,546 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, BooleanArray, LargeListArray, ListArray}; +use arrow::buffer::NullBuffer; +use arrow::compute::kernels::take::take; +use arrow::datatypes::{DataType, Field, Schema, UInt32Type}; +use arrow::record_batch::RecordBatch; +use datafusion::common::{DataFusionError, Result as DataFusionResult}; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use std::any::Any; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +const LAMBDA_VAR_COLUMN: &str = "__comet_lambda_var"; + +/// Decomposed list array: offsets as usize, values, and optional null buffer. +struct ListComponents { + offsets: Vec, + values: ArrayRef, + nulls: Option, +} + +impl ListComponents { + fn is_null(&self, row: usize) -> bool { + self.nulls.as_ref().is_some_and(|n| n.is_null(row)) + } +} + +fn decompose_list(array: &dyn Array) -> DataFusionResult { + if let Some(list) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: list.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(list.values()), + nulls: list.nulls().cloned(), + }) + } else if let Some(large) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: large.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(large.values()), + nulls: large.nulls().cloned(), + }) + } else { + Err(DataFusionError::Internal( + "ArrayExists expects a ListArray or LargeListArray input".to_string(), + )) + } +} + +/// Spark-compatible `array_exists(array, x -> predicate(x))`. +/// +/// Evaluates the lambda body vectorized over all elements in a single pass rather +/// than per-element to avoid repeated batch construction overhead. +#[derive(Debug, Eq)] +pub struct ArrayExistsExpr { + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, +} + +impl Hash for ArrayExistsExpr { + fn hash(&self, state: &mut H) { + self.array_expr.hash(state); + self.lambda_body.hash(state); + self.follow_three_valued_logic.hash(state); + } +} + +impl PartialEq for ArrayExistsExpr { + fn eq(&self, other: &Self) -> bool { + self.array_expr.eq(&other.array_expr) + && self.lambda_body.eq(&other.lambda_body) + && self + .follow_three_valued_logic + .eq(&other.follow_three_valued_logic) + } +} + +impl ArrayExistsExpr { + pub fn new( + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, + ) -> Self { + Self { + array_expr, + lambda_body, + follow_three_valued_logic, + } + } +} + +impl PhysicalExpr for ArrayExistsExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let num_rows = batch.num_rows(); + + let array_value = self.array_expr.evaluate(batch)?.into_array(num_rows)?; + let list = decompose_list(array_value.as_ref())?; + let total_elements = list.values.len(); + + if total_elements == 0 { + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list.is_null(row) { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + return Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))); + } + + let mut repeat_indices = Vec::with_capacity(total_elements); + for row in 0..num_rows { + let start = list.offsets[row]; + let end = list.offsets[row + 1]; + for _ in start..end { + repeat_indices.push(row as u32); + } + } + + let repeat_indices_array = arrow::array::PrimitiveArray::::from(repeat_indices); + + let mut expanded_columns: Vec = Vec::with_capacity(batch.num_columns() + 1); + let mut expanded_fields: Vec> = Vec::with_capacity(batch.num_columns() + 1); + + for (i, col) in batch.columns().iter().enumerate() { + let expanded = take(col.as_ref(), &repeat_indices_array, None)?; + expanded_columns.push(expanded); + expanded_fields.push(Arc::new(batch.schema().field(i).clone())); + } + + let element_field = Arc::new(Field::new( + LAMBDA_VAR_COLUMN, + list.values.data_type().clone(), + true, + )); + expanded_columns.push(Arc::clone(&list.values)); + expanded_fields.push(element_field); + + let expanded_schema = Arc::new(Schema::new(expanded_fields)); + let expanded_batch = RecordBatch::try_new(expanded_schema, expanded_columns)?; + + let body_result = self + .lambda_body + .evaluate(&expanded_batch)? + .into_array(total_elements)?; + + let body_booleans = body_result + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "ArrayExists lambda body must return BooleanArray".to_string(), + ) + })?; + + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list.is_null(row) { + result_builder.append_null(); + continue; + } + + let start = list.offsets[row]; + let end = list.offsets[row + 1]; + + if start == end { + result_builder.append_value(false); + continue; + } + + let mut found_true = false; + let mut found_null = false; + + for idx in start..end { + if body_booleans.is_null(idx) { + found_null = true; + } else if body_booleans.value(idx) { + found_true = true; + break; + } + } + + if found_true { + result_builder.append_value(true); + } else if found_null && self.follow_three_valued_logic { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + + Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.array_expr, &self.lambda_body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + match children.len() { + 2 => Ok(Arc::new(ArrayExistsExpr::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.follow_three_valued_logic, + ))), + _ => Err(DataFusionError::Internal( + "ArrayExistsExpr should have exactly two children".to_string(), + )), + } + } +} + +impl Display for ArrayExistsExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayExists [array: {:?}, lambda_body: {:?}]", + self.array_expr, self.lambda_body + ) + } +} + +#[derive(Debug, Eq)] +pub struct LambdaVariableExpr { + data_type: DataType, +} + +impl Hash for LambdaVariableExpr { + fn hash(&self, state: &mut H) { + self.data_type.hash(state); + } +} + +impl PartialEq for LambdaVariableExpr { + fn eq(&self, other: &Self) -> bool { + self.data_type == other.data_type + } +} + +impl LambdaVariableExpr { + pub fn new(data_type: DataType) -> Self { + Self { data_type } + } +} + +impl PhysicalExpr for LambdaVariableExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + // The lambda variable is always the last column, appended by ArrayExistsExpr + let idx = batch.num_columns() - 1; + Ok(ColumnarValue::Array(Arc::clone(batch.column(idx)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if children.is_empty() { + Ok(Arc::new(LambdaVariableExpr::new(self.data_type.clone()))) + } else { + Err(DataFusionError::Internal( + "LambdaVariableExpr should have no children".to_string(), + )) + } + } +} + +impl Display for LambdaVariableExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LambdaVariable({})", self.data_type) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::array::ListArray; + use arrow::datatypes::Int32Type; + use datafusion::physical_expr::expressions::{Column, Literal}; + use datafusion::{ + common::ScalarValue, logical_expr::Operator, physical_expr::expressions::BinaryExpr, + }; + + fn make_lambda_var_expr() -> Arc { + Arc::new(LambdaVariableExpr::new(DataType::Int32)) + } + + fn make_gt_predicate(threshold: i32) -> Arc { + Arc::new(BinaryExpr::new( + make_lambda_var_expr(), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(threshold)))), + )) + } + + #[test] + fn test_basic_exists() -> DataFusionResult<()> { + // exists(array(1, 2, 3), x -> x > 2) = true + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_empty_array() -> DataFusionResult<()> { + // exists(array(), x -> x > 0) = false + let list = + ListArray::from_iter_primitive::(vec![ + Some(Vec::>::new()), + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_null_array() -> DataFusionResult<()> { + // exists(null, x -> x > 0) = null + let list = + ListArray::from_iter_primitive::(vec![None::>>]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_three_valued_logic() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 5) = null (three-valued logic) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(5); + + // With three-valued logic: result should be null + let expr = ArrayExistsExpr::new(Arc::clone(&array_expr), Arc::clone(&lambda_body), true); + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + + // Without three-valued logic: result should be false + let expr2 = ArrayExistsExpr::new(array_expr, lambda_body, false); + let result2 = expr2.evaluate(&batch)?.into_array(1)?; + let bools2 = result2.as_any().downcast_ref::().unwrap(); + assert!(!bools2.is_null(0)); + assert!(!bools2.value(0)); + Ok(()) + } + + #[test] + fn test_null_elements_with_match() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 2) = true (because 3 > 2) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.is_null(0)); + assert!(bools.value(0)); + Ok(()) + } + + #[test] + fn test_multiple_rows() -> DataFusionResult<()> { + // Row 0: [1, 2, 3] -> x > 2 -> true + // Row 1: [1, 2] -> x > 2 -> false + // Row 2: null -> null + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(1), Some(2)]), + None, + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(3)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.value(1)); + assert!(bools.is_null(2)); + Ok(()) + } + + #[test] + fn test_multi_column_batch() -> DataFusionResult<()> { + // Verify batch expansion works correctly with additional columns + use arrow::array::Int32Array; + + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(5)]), + ]); + let extra_col = Int32Array::from(vec![100, 200]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("arr", list.data_type().clone(), true), + Field::new("extra", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list), Arc::new(extra_col)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(15); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(2)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); // [10, 20] has 20 > 15 + assert!(!bools.value(1)); // [5] has no element > 15 + Ok(()) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 3ef50a252f..64d8ac934f 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +mod array_exists; mod array_insert; mod get_array_struct_fields; mod list_extract; mod size; +pub use array_exists::{ArrayExistsExpr, LambdaVariableExpr}; pub use array_insert::ArrayInsert; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..6c4e606a52 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -50,6 +50,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayContains] -> CometArrayContains, classOf[ArrayDistinct] -> CometArrayDistinct, classOf[ArrayExcept] -> CometArrayExcept, + classOf[ArrayExists] -> CometArrayExists, classOf[ArrayFilter] -> CometArrayFilter, classOf[ArrayInsert] -> CometArrayInsert, classOf[ArrayIntersect] -> CometArrayIntersect, @@ -64,6 +65,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ElementAt] -> CometElementAt, classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, + classOf[NamedLambdaVariable] -> CometNamedLambdaVariable, classOf[Size] -> CometSize) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b7ebb9ba7b..240bad577e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -603,6 +603,80 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { } } +object CometArrayExists extends CometExpressionSerde[ArrayExists] { + + override def getSupportLevel(expr: ArrayExists): SupportLevel = { + val elementType = expr.argument.dataType.asInstanceOf[ArrayType].elementType + elementType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType => + Compatible() + case _ => Unsupported(Some(s"element type not supported: $elementType")) + } + } + + override def convert( + expr: ArrayExists, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExprProto = exprToProto(expr.argument, inputs, binding) + if (arrayExprProto.isEmpty) { + withInfo(expr, expr.argument) + return None + } + + expr.function match { + case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) => + val bodyProto = exprToProto(body, inputs, binding) + if (bodyProto.isEmpty) { + withInfo(expr, body) + return None + } + + val arrayExistsBuilder = ExprOuterClass.ArrayExists + .newBuilder() + .setArray(arrayExprProto.get) + .setLambdaBody(bodyProto.get) + .setFollowThreeValuedLogic(expr.followThreeValuedLogic) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayExists(arrayExistsBuilder) + .build()) + + case other => + withInfo(expr, s"Unsupported lambda function form: $other") + None + } + } + +} + +object CometNamedLambdaVariable extends CometExpressionSerde[NamedLambdaVariable] { + + override def convert( + expr: NamedLambdaVariable, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val dataTypeProto = serializeDataType(expr.dataType) + if (dataTypeProto.isEmpty) { + withInfo(expr, s"Cannot serialize data type: ${expr.dataType}") + return None + } + + val lambdaVarBuilder = ExprOuterClass.LambdaVariable + .newBuilder() + .setDatatype(dataTypeProto.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setLambdaVariable(lambdaVarBuilder) + .build()) + } +} + object CometSize extends CometExpressionSerde[Size] { override def getSupportLevel(expr: Size): SupportLevel = { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql new file mode 100644 index 0000000000..3ecbe7eaf0 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql @@ -0,0 +1,98 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_array_exists(arr_int array, arr_str array, arr_double array, arr_bool array, arr_long array, threshold int) USING parquet + +statement +INSERT INTO test_array_exists VALUES (array(1, 2, 3), array('a', 'bb', 'ccc'), array(1.5, 2.5, 3.5), array(false, false, true), array(100, 200, 300), 2), (array(1, 2), array('a', 'b'), array(0.5, 1.5), array(false, false), array(10, 20), 5), (array(), array(), array(), array(), array(), 0), (NULL, NULL, NULL, NULL, NULL, 1), (array(1, NULL, 3), array('a', NULL, 'ccc'), array(1.0, NULL, 3.0), array(true, NULL, false), array(100, NULL, 300), 2) + +-- basic: element satisfies predicate +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- no match +query +SELECT exists(arr_int, x -> x > 100) FROM test_array_exists + +-- empty array returns false +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists + +-- null array returns null +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists WHERE arr_int IS NULL + +-- predicate referencing outer column +query +SELECT exists(arr_int, x -> x > threshold) FROM test_array_exists + +-- three-valued logic: null elements with no match -> null +query +SELECT exists(arr_int, x -> x > 5) FROM test_array_exists + +-- null elements but match exists -> true +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- string type +query +SELECT exists(arr_str, x -> length(x) > 2) FROM test_array_exists + +-- double type +query +SELECT exists(arr_double, x -> x > 2.0) FROM test_array_exists + +-- boolean type +query +SELECT exists(arr_bool, x -> x) FROM test_array_exists + +-- long type +query +SELECT exists(arr_long, x -> x > 250) FROM test_array_exists + +-- literal arrays +query +SELECT exists(array(1, 2, 3), x -> x > 2) + +query +SELECT exists(array(1, 2, 3), x -> x > 5) + +-- empty literal array has NullType element type, which is unsupported +query spark_answer_only +SELECT exists(array(), x -> cast(x as int) > 0) + +query +SELECT exists(cast(NULL as array), x -> x > 0) + +-- null elements in literal array with three-valued logic +query +SELECT exists(array(1, NULL, 3), x -> x > 5) + +-- null elements in literal array with match +query +SELECT exists(array(1, NULL, 3), x -> x > 2) + +-- complex predicate +query +SELECT exists(arr_int, x -> x > 1 AND x < 3) FROM test_array_exists + +-- predicate with modulo +query +SELECT exists(arr_int, x -> x % 2 = 0) FROM test_array_exists diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index b22d0f72db..9bea8d25ef 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -922,4 +922,80 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } } + + test("array_exists - DataFrame API") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array, threshold int) using parquet") + sql(s"insert into $table values (array(1, 2, 3), 2)") + sql(s"insert into $table values (array(1, 2), 5)") + sql(s"insert into $table values (array(), 0)") + sql(s"insert into $table values (null, 1)") + sql(s"insert into $table values (array(1, null, 3), 2)") + + val df = spark.table(table) + + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > col("threshold")))) + checkSparkAnswerAndOperator( + df.select( + exists(col("arr"), x => x > 0).as("any_positive"), + exists(col("arr"), x => x > 100).as("any_large"))) + } + } + + test("array_exists - DataFrame API with decimal") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1.50, 2.75, 3.25))") + sql(s"insert into $table values (array(0.10, 0.20))") + + val df = spark.table(table) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2.0))) + } + } + + test("array_exists - DataFrame API with date") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(date'2024-01-01', date'2024-06-15'))") + sql(s"insert into $table values (array(date'2023-01-01'))") + + val df = spark.table(table) + checkSparkAnswerAndOperator( + df.select(exists(col("arr"), x => x > lit("2024-03-01").cast("date")))) + } + } + + test("array_exists - fallback for unsupported element type") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(X'01', X'02'))") + + val df = spark.table(table) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => x.isNotNull)), + "element type not supported") + } + } + + test("array_exists - fallback with UDF in lambda") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, 2, 3))") + sql(s"insert into $table values (array(4, 5, 6))") + sql(s"insert into $table values (null)") + + val isEven = udf((x: Int) => x % 2 == 0) + + val df = spark.table(table) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => isEven(x))), + "scalaudf is not supported") + } + } }