From e91b23eff98cafd9b7f2f1b0950a718021db43c6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 27 Feb 2026 06:50:52 -0700 Subject: [PATCH 1/3] feat: add array_exists with lambda expression support Add native support for `array_exists(arr, x -> predicate(x))` in SQL and DataFrame API. This is the first general-purpose lambda expression infrastructure, which can later be extended to support `array_filter`, `array_transform`, and `array_forall`. The lambda body is serialized as a regular expression tree where `NamedLambdaVariable` leaf nodes are serialized as `LambdaVariable` proto messages. On the Rust side, `ArrayExistsExpr` evaluates the lambda body vectorized over all elements in a single pass: it flattens list values, expands the batch with repeat indices, appends elements as a `__comet_lambda_var` column, evaluates once, and reduces per row with SQL three-valued logic semantics. Unsupported lambda bodies (e.g. containing UDFs) fall back to Spark. Closes #3149 --- .../core/src/execution/expressions/array.rs | 67 +++ native/core/src/execution/expressions/mod.rs | 1 + .../execution/planner/expression_registry.rs | 19 + native/proto/src/proto/expr.proto | 13 + .../src/array_funcs/array_exists.rs | 499 ++++++++++++++++++ native/spark-expr/src/array_funcs/mod.rs | 2 + .../apache/comet/serde/QueryPlanSerde.scala | 2 + .../scala/org/apache/comet/serde/arrays.scala | 82 ++- .../expressions/array/array_exists.sql | 98 ++++ .../comet/CometArrayExpressionSuite.scala | 94 ++++ 10 files changed, 876 insertions(+), 1 deletion(-) create mode 100644 native/core/src/execution/expressions/array.rs create mode 100644 native/spark-expr/src/array_funcs/array_exists.rs create mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_exists.sql diff --git a/native/core/src/execution/expressions/array.rs b/native/core/src/execution/expressions/array.rs new file mode 100644 index 0000000000..7e8921c8d8 --- /dev/null +++ b/native/core/src/execution/expressions/array.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array expression builders + +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use datafusion::physical_expr::PhysicalExpr; +use datafusion_comet_proto::spark_expression::Expr; +use datafusion_comet_spark_expr::{ArrayExistsExpr, LambdaVariableExpr}; + +use crate::execution::operators::ExecutionError; +use crate::execution::planner::expression_registry::ExpressionBuilder; +use crate::execution::planner::PhysicalPlanner; +use crate::execution::serde::to_arrow_datatype; +use crate::extract_expr; + +pub struct ArrayExistsBuilder; + +impl ExpressionBuilder for ArrayExistsBuilder { + fn build( + &self, + spark_expr: &Expr, + input_schema: SchemaRef, + planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, ArrayExists); + let array_expr = + planner.create_expr(expr.array.as_ref().unwrap(), Arc::clone(&input_schema))?; + let lambda_body = planner.create_expr(expr.lambda_body.as_ref().unwrap(), input_schema)?; + Ok(Arc::new(ArrayExistsExpr::new( + array_expr, + lambda_body, + expr.follow_three_valued_logic, + ))) + } +} + +pub struct LambdaVariableBuilder; + +impl ExpressionBuilder for LambdaVariableBuilder { + fn build( + &self, + spark_expr: &Expr, + _input_schema: SchemaRef, + _planner: &PhysicalPlanner, + ) -> Result, ExecutionError> { + let expr = extract_expr!(spark_expr, LambdaVariable); + let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + Ok(Arc::new(LambdaVariableExpr::new(data_type))) + } +} diff --git a/native/core/src/execution/expressions/mod.rs b/native/core/src/execution/expressions/mod.rs index 563d62e91b..02f50f1b7e 100644 --- a/native/core/src/execution/expressions/mod.rs +++ b/native/core/src/execution/expressions/mod.rs @@ -18,6 +18,7 @@ //! Native DataFusion expressions pub mod arithmetic; +pub mod array; pub mod bitwise; pub mod comparison; pub mod logical; diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs index 34aa3de179..8097eb94d1 100644 --- a/native/core/src/execution/planner/expression_registry.rs +++ b/native/core/src/execution/planner/expression_registry.rs @@ -103,6 +103,8 @@ pub enum ExpressionType { Randn, SparkPartitionId, MonotonicallyIncreasingId, + ArrayExists, + LambdaVariable, // Time functions Hour, @@ -184,6 +186,9 @@ impl ExpressionRegistry { // Register temporal expressions self.register_temporal_expressions(); + + // Register array expressions + self.register_array_expressions(); } /// Register arithmetic expression builders @@ -306,6 +311,18 @@ impl ExpressionRegistry { ); } + /// Register array expression builders + fn register_array_expressions(&mut self) { + use crate::execution::expressions::array::*; + + self.builders + .insert(ExpressionType::ArrayExists, Box::new(ArrayExistsBuilder)); + self.builders.insert( + ExpressionType::LambdaVariable, + Box::new(LambdaVariableBuilder), + ); + } + /// Extract expression type from Spark protobuf expression fn get_expression_type(spark_expr: &Expr) -> Result { match spark_expr.expr_struct.as_ref() { @@ -370,6 +387,8 @@ impl ExpressionRegistry { Some(ExprStruct::MonotonicallyIncreasingId(_)) => { Ok(ExpressionType::MonotonicallyIncreasingId) } + Some(ExprStruct::ArrayExists(_)) => Ok(ExpressionType::ArrayExists), + Some(ExprStruct::LambdaVariable(_)) => Ok(ExpressionType::LambdaVariable), Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 944505ba1c..ce97a7a27e 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -88,6 +88,8 @@ message Expr { UnixTimestamp unix_timestamp = 65; FromJson from_json = 66; ToCsv to_csv = 67; + ArrayExists array_exists = 68; + LambdaVariable lambda_variable = 69; } } @@ -440,3 +442,14 @@ message ArrayJoin { message Rand { int64 seed = 1; } + +message ArrayExists { + Expr array = 1; + Expr lambda_body = 2; + DataType element_type = 3; + bool follow_three_valued_logic = 4; +} + +message LambdaVariable { + DataType datatype = 1; +} diff --git a/native/spark-expr/src/array_funcs/array_exists.rs b/native/spark-expr/src/array_funcs/array_exists.rs new file mode 100644 index 0000000000..e45d0fb3ee --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_exists.rs @@ -0,0 +1,499 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, ArrayRef, BooleanArray, ListArray}; +use arrow::compute::kernels::take::take; +use arrow::datatypes::{DataType, Field, Schema, UInt32Type}; +use arrow::record_batch::RecordBatch; +use datafusion::common::{DataFusionError, Result as DataFusionResult}; +use datafusion::logical_expr::ColumnarValue; +use datafusion::physical_expr::PhysicalExpr; +use std::any::Any; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +const LAMBDA_VAR_COLUMN: &str = "__comet_lambda_var"; + +/// Spark-compatible `array_exists(array, x -> predicate(x))`. +/// +/// Evaluates the lambda body vectorized over all elements in a single pass rather +/// than per-element to avoid repeated batch construction overhead. +#[derive(Debug, Eq)] +pub struct ArrayExistsExpr { + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, +} + +impl Hash for ArrayExistsExpr { + fn hash(&self, state: &mut H) { + self.array_expr.hash(state); + self.lambda_body.hash(state); + self.follow_three_valued_logic.hash(state); + } +} + +impl PartialEq for ArrayExistsExpr { + fn eq(&self, other: &Self) -> bool { + self.array_expr.eq(&other.array_expr) + && self.lambda_body.eq(&other.lambda_body) + && self + .follow_three_valued_logic + .eq(&other.follow_three_valued_logic) + } +} + +impl ArrayExistsExpr { + pub fn new( + array_expr: Arc, + lambda_body: Arc, + follow_three_valued_logic: bool, + ) -> Self { + Self { + array_expr, + lambda_body, + follow_three_valued_logic, + } + } +} + +impl PhysicalExpr for ArrayExistsExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let num_rows = batch.num_rows(); + + // Evaluate the array expression + let array_value = self.array_expr.evaluate(batch)?.into_array(num_rows)?; + + let list_array = array_value + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal("ArrayExists expects a ListArray input".to_string()) + })?; + + let offsets = list_array.offsets(); + let values = list_array.values(); + let total_elements = values.len(); + + if total_elements == 0 { + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list_array.is_null(row) { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + return Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))); + } + + let mut repeat_indices = Vec::with_capacity(total_elements); + for row in 0..num_rows { + let start = offsets[row] as usize; + let end = offsets[row + 1] as usize; + for _ in start..end { + repeat_indices.push(row as u32); + } + } + + let repeat_indices_array = arrow::array::PrimitiveArray::::from(repeat_indices); + + let mut expanded_columns: Vec = Vec::with_capacity(batch.num_columns() + 1); + let mut expanded_fields: Vec> = Vec::with_capacity(batch.num_columns() + 1); + + for (i, col) in batch.columns().iter().enumerate() { + let expanded = take(col.as_ref(), &repeat_indices_array, None)?; + expanded_columns.push(expanded); + expanded_fields.push(Arc::new(batch.schema().field(i).clone())); + } + + let element_field = Arc::new(Field::new( + LAMBDA_VAR_COLUMN, + values.data_type().clone(), + true, + )); + expanded_columns.push(Arc::clone(values)); + expanded_fields.push(element_field); + + let expanded_schema = Arc::new(Schema::new(expanded_fields)); + let expanded_batch = RecordBatch::try_new(expanded_schema, expanded_columns)?; + + let body_result = self + .lambda_body + .evaluate(&expanded_batch)? + .into_array(total_elements)?; + + let body_booleans = body_result + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal( + "ArrayExists lambda body must return BooleanArray".to_string(), + ) + })?; + + let mut result_builder = BooleanArray::builder(num_rows); + for row in 0..num_rows { + if list_array.is_null(row) { + result_builder.append_null(); + continue; + } + + let start = offsets[row] as usize; + let end = offsets[row + 1] as usize; + + if start == end { + result_builder.append_value(false); + continue; + } + + let mut found_true = false; + let mut found_null = false; + + for idx in start..end { + if body_booleans.is_null(idx) { + found_null = true; + } else if body_booleans.value(idx) { + found_true = true; + break; + } + } + + if found_true { + result_builder.append_value(true); + } else if found_null && self.follow_three_valued_logic { + result_builder.append_null(); + } else { + result_builder.append_value(false); + } + } + + Ok(ColumnarValue::Array(Arc::new(result_builder.finish()))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.array_expr, &self.lambda_body] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + match children.len() { + 2 => Ok(Arc::new(ArrayExistsExpr::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.follow_three_valued_logic, + ))), + _ => Err(DataFusionError::Internal( + "ArrayExistsExpr should have exactly two children".to_string(), + )), + } + } +} + +impl Display for ArrayExistsExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayExists [array: {:?}, lambda_body: {:?}]", + self.array_expr, self.lambda_body + ) + } +} + +#[derive(Debug, Eq)] +pub struct LambdaVariableExpr { + data_type: DataType, +} + +impl Hash for LambdaVariableExpr { + fn hash(&self, state: &mut H) { + self.data_type.hash(state); + } +} + +impl PartialEq for LambdaVariableExpr { + fn eq(&self, other: &Self) -> bool { + self.data_type == other.data_type + } +} + +impl LambdaVariableExpr { + pub fn new(data_type: DataType) -> Self { + Self { data_type } + } +} + +impl PhysicalExpr for LambdaVariableExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(self, f) + } + + fn data_type(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let schema = batch.schema(); + let idx = schema.index_of(LAMBDA_VAR_COLUMN).map_err(|_| { + DataFusionError::Internal(format!( + "Lambda variable column '{}' not found in batch schema", + LAMBDA_VAR_COLUMN + )) + })?; + Ok(ColumnarValue::Array(Arc::clone(batch.column(idx)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + if children.is_empty() { + Ok(Arc::new(LambdaVariableExpr::new(self.data_type.clone()))) + } else { + Err(DataFusionError::Internal( + "LambdaVariableExpr should have no children".to_string(), + )) + } + } +} + +impl Display for LambdaVariableExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "LambdaVariable({})", self.data_type) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::array::ListArray; + use arrow::datatypes::Int32Type; + use datafusion::physical_expr::expressions::{Column, Literal}; + use datafusion::{ + common::ScalarValue, logical_expr::Operator, physical_expr::expressions::BinaryExpr, + }; + + fn make_lambda_var_expr() -> Arc { + Arc::new(LambdaVariableExpr::new(DataType::Int32)) + } + + fn make_gt_predicate(threshold: i32) -> Arc { + Arc::new(BinaryExpr::new( + make_lambda_var_expr(), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(threshold)))), + )) + } + + #[test] + fn test_basic_exists() -> DataFusionResult<()> { + // exists(array(1, 2, 3), x -> x > 2) = true + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_empty_array() -> DataFusionResult<()> { + // exists(array(), x -> x > 0) = false + let list = + ListArray::from_iter_primitive::(vec![ + Some(Vec::>::new()), + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.value(0)); + assert!(!bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_null_array() -> DataFusionResult<()> { + // exists(null, x -> x > 0) = null + let list = + ListArray::from_iter_primitive::(vec![None::>>]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(0); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + Ok(()) + } + + #[test] + fn test_three_valued_logic() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 5) = null (three-valued logic) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(5); + + // With three-valued logic: result should be null + let expr = ArrayExistsExpr::new(Arc::clone(&array_expr), Arc::clone(&lambda_body), true); + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.is_null(0)); + + // Without three-valued logic: result should be false + let expr2 = ArrayExistsExpr::new(array_expr, lambda_body, false); + let result2 = expr2.evaluate(&batch)?.into_array(1)?; + let bools2 = result2.as_any().downcast_ref::().unwrap(); + assert!(!bools2.is_null(0)); + assert!(!bools2.value(0)); + Ok(()) + } + + #[test] + fn test_null_elements_with_match() -> DataFusionResult<()> { + // exists(array(1, null, 3), x -> x > 2) = true (because 3 > 2) + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(3), + ])]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(1)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(!bools.is_null(0)); + assert!(bools.value(0)); + Ok(()) + } + + #[test] + fn test_multiple_rows() -> DataFusionResult<()> { + // Row 0: [1, 2, 3] -> x > 2 -> true + // Row 1: [1, 2] -> x > 2 -> false + // Row 2: null -> null + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(1), Some(2)]), + None, + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "arr", + list.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(2); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(3)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); + assert!(!bools.value(1)); + assert!(bools.is_null(2)); + Ok(()) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 3ef50a252f..64d8ac934f 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -15,11 +15,13 @@ // specific language governing permissions and limitations // under the License. +mod array_exists; mod array_insert; mod get_array_struct_fields; mod list_extract; mod size; +pub use array_exists::{ArrayExistsExpr, LambdaVariableExpr}; pub use array_insert::ArrayInsert; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9d13ccd9ed..6c4e606a52 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -50,6 +50,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArrayContains] -> CometArrayContains, classOf[ArrayDistinct] -> CometArrayDistinct, classOf[ArrayExcept] -> CometArrayExcept, + classOf[ArrayExists] -> CometArrayExists, classOf[ArrayFilter] -> CometArrayFilter, classOf[ArrayInsert] -> CometArrayInsert, classOf[ArrayIntersect] -> CometArrayIntersect, @@ -64,6 +65,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ElementAt] -> CometElementAt, classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, + classOf[NamedLambdaVariable] -> CometNamedLambdaVariable, classOf[Size] -> CometSize) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b7ebb9ba7b..70e8a9153b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayExists, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, LambdaFunction, Literal, NamedLambdaVariable, Reverse, Size} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -603,6 +603,86 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] { } } +object CometArrayExists extends CometExpressionSerde[ArrayExists] { + + override def getSupportLevel(expr: ArrayExists): SupportLevel = { + val elementType = expr.argument.dataType.asInstanceOf[ArrayType].elementType + elementType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | + _: DecimalType | DateType | TimestampType | StringType => + Compatible() + case _ => Unsupported(Some(s"element type not supported: $elementType")) + } + } + + override def convert( + expr: ArrayExists, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val arrayExprProto = exprToProto(expr.argument, inputs, binding) + if (arrayExprProto.isEmpty) { + withInfo(expr, expr.argument) + return None + } + + expr.function match { + case LambdaFunction(body, Seq(elementVar: NamedLambdaVariable), _) => + val bodyProto = exprToProto(body, inputs, binding) + if (bodyProto.isEmpty) { + withInfo(expr, body) + return None + } + + val elementType = elementVar.dataType + val elementTypeProto = serializeDataType(elementType) + if (elementTypeProto.isEmpty) { + withInfo(expr, s"Cannot serialize element type: $elementType") + return None + } + + val arrayExistsBuilder = ExprOuterClass.ArrayExists + .newBuilder() + .setArray(arrayExprProto.get) + .setLambdaBody(bodyProto.get) + .setElementType(elementTypeProto.get) + .setFollowThreeValuedLogic(expr.followThreeValuedLogic) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setArrayExists(arrayExistsBuilder) + .build()) + + case other => + withInfo(expr, s"Unsupported lambda function form: $other") + None + } + } +} + +object CometNamedLambdaVariable extends CometExpressionSerde[NamedLambdaVariable] { + override def convert( + expr: NamedLambdaVariable, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val dataTypeProto = serializeDataType(expr.dataType) + if (dataTypeProto.isEmpty) { + withInfo(expr, s"Cannot serialize data type: ${expr.dataType}") + return None + } + + val lambdaVarBuilder = ExprOuterClass.LambdaVariable + .newBuilder() + .setDatatype(dataTypeProto.get) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setLambdaVariable(lambdaVarBuilder) + .build()) + } +} + object CometSize extends CometExpressionSerde[Size] { override def getSupportLevel(expr: Size): SupportLevel = { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql new file mode 100644 index 0000000000..3ecbe7eaf0 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_exists.sql @@ -0,0 +1,98 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +statement +CREATE TABLE test_array_exists(arr_int array, arr_str array, arr_double array, arr_bool array, arr_long array, threshold int) USING parquet + +statement +INSERT INTO test_array_exists VALUES (array(1, 2, 3), array('a', 'bb', 'ccc'), array(1.5, 2.5, 3.5), array(false, false, true), array(100, 200, 300), 2), (array(1, 2), array('a', 'b'), array(0.5, 1.5), array(false, false), array(10, 20), 5), (array(), array(), array(), array(), array(), 0), (NULL, NULL, NULL, NULL, NULL, 1), (array(1, NULL, 3), array('a', NULL, 'ccc'), array(1.0, NULL, 3.0), array(true, NULL, false), array(100, NULL, 300), 2) + +-- basic: element satisfies predicate +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- no match +query +SELECT exists(arr_int, x -> x > 100) FROM test_array_exists + +-- empty array returns false +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists + +-- null array returns null +query +SELECT exists(arr_int, x -> x > 0) FROM test_array_exists WHERE arr_int IS NULL + +-- predicate referencing outer column +query +SELECT exists(arr_int, x -> x > threshold) FROM test_array_exists + +-- three-valued logic: null elements with no match -> null +query +SELECT exists(arr_int, x -> x > 5) FROM test_array_exists + +-- null elements but match exists -> true +query +SELECT exists(arr_int, x -> x > 2) FROM test_array_exists + +-- string type +query +SELECT exists(arr_str, x -> length(x) > 2) FROM test_array_exists + +-- double type +query +SELECT exists(arr_double, x -> x > 2.0) FROM test_array_exists + +-- boolean type +query +SELECT exists(arr_bool, x -> x) FROM test_array_exists + +-- long type +query +SELECT exists(arr_long, x -> x > 250) FROM test_array_exists + +-- literal arrays +query +SELECT exists(array(1, 2, 3), x -> x > 2) + +query +SELECT exists(array(1, 2, 3), x -> x > 5) + +-- empty literal array has NullType element type, which is unsupported +query spark_answer_only +SELECT exists(array(), x -> cast(x as int) > 0) + +query +SELECT exists(cast(NULL as array), x -> x > 0) + +-- null elements in literal array with three-valued logic +query +SELECT exists(array(1, NULL, 3), x -> x > 5) + +-- null elements in literal array with match +query +SELECT exists(array(1, NULL, 3), x -> x > 2) + +-- complex predicate +query +SELECT exists(arr_int, x -> x > 1 AND x < 3) FROM test_array_exists + +-- predicate with modulo +query +SELECT exists(arr_int, x -> x % 2 = 0) FROM test_array_exists diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index b22d0f72db..3bb3c2b6d1 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -922,4 +922,98 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } } + + test("array_exists - basic") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array, threshold int) using parquet") + sql(s"insert into $table values (array(1, 2, 3), 2)") + sql(s"insert into $table values (array(1, 2), 5)") + sql(s"insert into $table values (array(), 0)") + sql(s"insert into $table values (null, 1)") + + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2) from $table")) + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > threshold) from $table")) + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 0) from $table")) + } + } + + test("array_exists - null elements and three-valued logic") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, null, 3))") + sql(s"insert into $table values (array(null, null))") + sql(s"insert into $table values (array(1, 2, 3))") + + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 5) from $table")) + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2) from $table")) + } + } + + test("array_exists - various types") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array('a', 'bb', 'ccc'))") + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> length(x) > 2) from $table")) + } + + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1.5, 2.5, 3.5))") + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2.0) from $table")) + } + + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(false, false, true))") + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x) from $table")) + } + + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(100, 200, 300))") + checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 250) from $table")) + } + } + + test("array_exists - DataFrame API") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array, threshold int) using parquet") + sql(s"insert into $table values (array(1, 2, 3), 2)") + sql(s"insert into $table values (array(1, 2), 5)") + sql(s"insert into $table values (array(), 0)") + sql(s"insert into $table values (null, 1)") + sql(s"insert into $table values (array(1, null, 3), 2)") + + val df = spark.table(table) + + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > col("threshold")))) + checkSparkAnswerAndOperator( + df.select( + exists(col("arr"), x => x > 0).as("any_positive"), + exists(col("arr"), x => x > 100).as("any_large"))) + } + } + + test("array_exists - fallback with UDF in lambda") { + val table = "t1" + withTable(table) { + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1, 2, 3))") + sql(s"insert into $table values (array(4, 5, 6))") + sql(s"insert into $table values (null)") + + val isEven = udf((x: Int) => x % 2 == 0) + + val df = spark.table(table) + // UDF in lambda body cannot be serialized to native code + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => isEven(x))), + "scalaudf is not supported") + } + } } From a787730daa07a7f767ffe5a0f558262b25c8d8cc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 27 Feb 2026 07:24:08 -0700 Subject: [PATCH 2/3] fix: address review feedback for array_exists - Remove unused element_type proto field from ArrayExists - Add LargeListArray support via decompose_list helper - Use column index instead of name for lambda variable lookup - Add TimestampNTZType to supported element types - Restore CometNamedLambdaVariable as standalone serde object - Remove SQL-based Scala tests (covered by SQL file tests) - Add DataFrame tests for decimal and date element types - Add negative test for unsupported element type fallback - Add multi-column batch Rust unit test --- native/proto/src/proto/expr.proto | 7 +- .../src/array_funcs/array_exists.rs | 103 +++++++++++++----- .../scala/org/apache/comet/serde/arrays.scala | 12 +- .../comet/CometArrayExpressionSuite.scala | 76 +++++-------- 4 files changed, 112 insertions(+), 86 deletions(-) diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index ce97a7a27e..d3420fefd4 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -446,10 +446,13 @@ message Rand { message ArrayExists { Expr array = 1; Expr lambda_body = 2; - DataType element_type = 3; - bool follow_three_valued_logic = 4; + bool follow_three_valued_logic = 3; } +// Currently only supports a single lambda variable per expression. The variable +// is resolved by column index (always the last column in the expanded batch +// constructed by ArrayExistsExpr). Extending to multi-argument lambdas +// (e.g. transform(array, (x, i) -> ...)) would require adding an identifier. message LambdaVariable { DataType datatype = 1; } diff --git a/native/spark-expr/src/array_funcs/array_exists.rs b/native/spark-expr/src/array_funcs/array_exists.rs index e45d0fb3ee..c0be034be2 100644 --- a/native/spark-expr/src/array_funcs/array_exists.rs +++ b/native/spark-expr/src/array_funcs/array_exists.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, BooleanArray, ListArray}; +use arrow::array::{Array, ArrayRef, BooleanArray, LargeListArray, ListArray}; +use arrow::buffer::NullBuffer; use arrow::compute::kernels::take::take; use arrow::datatypes::{DataType, Field, Schema, UInt32Type}; use arrow::record_batch::RecordBatch; @@ -29,6 +30,39 @@ use std::sync::Arc; const LAMBDA_VAR_COLUMN: &str = "__comet_lambda_var"; +/// Decomposed list array: offsets as usize, values, and optional null buffer. +struct ListComponents { + offsets: Vec, + values: ArrayRef, + nulls: Option, +} + +impl ListComponents { + fn is_null(&self, row: usize) -> bool { + self.nulls.as_ref().is_some_and(|n| n.is_null(row)) + } +} + +fn decompose_list(array: &dyn Array) -> DataFusionResult { + if let Some(list) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: list.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(list.values()), + nulls: list.nulls().cloned(), + }) + } else if let Some(large) = array.as_any().downcast_ref::() { + Ok(ListComponents { + offsets: large.offsets().iter().map(|&o| o as usize).collect(), + values: Arc::clone(large.values()), + nulls: large.nulls().cloned(), + }) + } else { + Err(DataFusionError::Internal( + "ArrayExists expects a ListArray or LargeListArray input".to_string(), + )) + } +} + /// Spark-compatible `array_exists(array, x -> predicate(x))`. /// /// Evaluates the lambda body vectorized over all elements in a single pass rather @@ -92,24 +126,14 @@ impl PhysicalExpr for ArrayExistsExpr { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let num_rows = batch.num_rows(); - // Evaluate the array expression let array_value = self.array_expr.evaluate(batch)?.into_array(num_rows)?; - - let list_array = array_value - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal("ArrayExists expects a ListArray input".to_string()) - })?; - - let offsets = list_array.offsets(); - let values = list_array.values(); - let total_elements = values.len(); + let list = decompose_list(array_value.as_ref())?; + let total_elements = list.values.len(); if total_elements == 0 { let mut result_builder = BooleanArray::builder(num_rows); for row in 0..num_rows { - if list_array.is_null(row) { + if list.is_null(row) { result_builder.append_null(); } else { result_builder.append_value(false); @@ -120,8 +144,8 @@ impl PhysicalExpr for ArrayExistsExpr { let mut repeat_indices = Vec::with_capacity(total_elements); for row in 0..num_rows { - let start = offsets[row] as usize; - let end = offsets[row + 1] as usize; + let start = list.offsets[row]; + let end = list.offsets[row + 1]; for _ in start..end { repeat_indices.push(row as u32); } @@ -140,10 +164,10 @@ impl PhysicalExpr for ArrayExistsExpr { let element_field = Arc::new(Field::new( LAMBDA_VAR_COLUMN, - values.data_type().clone(), + list.values.data_type().clone(), true, )); - expanded_columns.push(Arc::clone(values)); + expanded_columns.push(Arc::clone(&list.values)); expanded_fields.push(element_field); let expanded_schema = Arc::new(Schema::new(expanded_fields)); @@ -165,13 +189,13 @@ impl PhysicalExpr for ArrayExistsExpr { let mut result_builder = BooleanArray::builder(num_rows); for row in 0..num_rows { - if list_array.is_null(row) { + if list.is_null(row) { result_builder.append_null(); continue; } - let start = offsets[row] as usize; - let end = offsets[row + 1] as usize; + let start = list.offsets[row]; + let end = list.offsets[row + 1]; if start == end { result_builder.append_value(false); @@ -274,13 +298,8 @@ impl PhysicalExpr for LambdaVariableExpr { } fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let schema = batch.schema(); - let idx = schema.index_of(LAMBDA_VAR_COLUMN).map_err(|_| { - DataFusionError::Internal(format!( - "Lambda variable column '{}' not found in batch schema", - LAMBDA_VAR_COLUMN - )) - })?; + // The lambda variable is always the last column, appended by ArrayExistsExpr + let idx = batch.num_columns() - 1; Ok(ColumnarValue::Array(Arc::clone(batch.column(idx)))) } @@ -496,4 +515,32 @@ mod test { assert!(bools.is_null(2)); Ok(()) } + + #[test] + fn test_multi_column_batch() -> DataFusionResult<()> { + // Verify batch expansion works correctly with additional columns + use arrow::array::Int32Array; + + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(10), Some(20)]), + Some(vec![Some(5)]), + ]); + let extra_col = Int32Array::from(vec![100, 200]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("arr", list.data_type().clone(), true), + Field::new("extra", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(list), Arc::new(extra_col)])?; + + let array_expr: Arc = Arc::new(Column::new("arr", 0)); + let lambda_body = make_gt_predicate(15); + let expr = ArrayExistsExpr::new(array_expr, lambda_body, true); + + let result = expr.evaluate(&batch)?.into_array(2)?; + let bools = result.as_any().downcast_ref::().unwrap(); + assert!(bools.value(0)); // [10, 20] has 20 > 15 + assert!(!bools.value(1)); // [5] has no element > 15 + Ok(()) + } } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 70e8a9153b..c95d9d7ef4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -609,7 +609,7 @@ object CometArrayExists extends CometExpressionSerde[ArrayExists] { val elementType = expr.argument.dataType.asInstanceOf[ArrayType].elementType elementType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | - _: DecimalType | DateType | TimestampType | StringType => + _: DecimalType | DateType | TimestampType | TimestampNTZType | StringType => Compatible() case _ => Unsupported(Some(s"element type not supported: $elementType")) } @@ -633,18 +633,10 @@ object CometArrayExists extends CometExpressionSerde[ArrayExists] { return None } - val elementType = elementVar.dataType - val elementTypeProto = serializeDataType(elementType) - if (elementTypeProto.isEmpty) { - withInfo(expr, s"Cannot serialize element type: $elementType") - return None - } - val arrayExistsBuilder = ExprOuterClass.ArrayExists .newBuilder() .setArray(arrayExprProto.get) .setLambdaBody(bodyProto.get) - .setElementType(elementTypeProto.get) .setFollowThreeValuedLogic(expr.followThreeValuedLogic) Some( @@ -658,9 +650,11 @@ object CometArrayExists extends CometExpressionSerde[ArrayExists] { None } } + } object CometNamedLambdaVariable extends CometExpressionSerde[NamedLambdaVariable] { + override def convert( expr: NamedLambdaVariable, inputs: Seq[Attribute], diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 3bb3c2b6d1..9bea8d25ef 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -923,7 +923,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } } - test("array_exists - basic") { + test("array_exists - DataFrame API") { val table = "t1" withTable(table) { sql(s"create table $table(arr array, threshold int) using parquet") @@ -931,71 +931,54 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp sql(s"insert into $table values (array(1, 2), 5)") sql(s"insert into $table values (array(), 0)") sql(s"insert into $table values (null, 1)") + sql(s"insert into $table values (array(1, null, 3), 2)") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2) from $table")) - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > threshold) from $table")) - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 0) from $table")) + val df = spark.table(table) + + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2))) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > col("threshold")))) + checkSparkAnswerAndOperator( + df.select( + exists(col("arr"), x => x > 0).as("any_positive"), + exists(col("arr"), x => x > 100).as("any_large"))) } } - test("array_exists - null elements and three-valued logic") { + test("array_exists - DataFrame API with decimal") { val table = "t1" withTable(table) { - sql(s"create table $table(arr array) using parquet") - sql(s"insert into $table values (array(1, null, 3))") - sql(s"insert into $table values (array(null, null))") - sql(s"insert into $table values (array(1, 2, 3))") + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(1.50, 2.75, 3.25))") + sql(s"insert into $table values (array(0.10, 0.20))") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 5) from $table")) - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2) from $table")) + val df = spark.table(table) + checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2.0))) } } - test("array_exists - various types") { + test("array_exists - DataFrame API with date") { val table = "t1" withTable(table) { - sql(s"create table $table(arr array) using parquet") - sql(s"insert into $table values (array('a', 'bb', 'ccc'))") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> length(x) > 2) from $table")) - } - - withTable(table) { - sql(s"create table $table(arr array) using parquet") - sql(s"insert into $table values (array(1.5, 2.5, 3.5))") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 2.0) from $table")) - } - - withTable(table) { - sql(s"create table $table(arr array) using parquet") - sql(s"insert into $table values (array(false, false, true))") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x) from $table")) - } + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(date'2024-01-01', date'2024-06-15'))") + sql(s"insert into $table values (array(date'2023-01-01'))") - withTable(table) { - sql(s"create table $table(arr array) using parquet") - sql(s"insert into $table values (array(100, 200, 300))") - checkSparkAnswerAndOperator(sql(s"select exists(arr, x -> x > 250) from $table")) + val df = spark.table(table) + checkSparkAnswerAndOperator( + df.select(exists(col("arr"), x => x > lit("2024-03-01").cast("date")))) } } - test("array_exists - DataFrame API") { + test("array_exists - fallback for unsupported element type") { val table = "t1" withTable(table) { - sql(s"create table $table(arr array, threshold int) using parquet") - sql(s"insert into $table values (array(1, 2, 3), 2)") - sql(s"insert into $table values (array(1, 2), 5)") - sql(s"insert into $table values (array(), 0)") - sql(s"insert into $table values (null, 1)") - sql(s"insert into $table values (array(1, null, 3), 2)") + sql(s"create table $table(arr array) using parquet") + sql(s"insert into $table values (array(X'01', X'02'))") val df = spark.table(table) - - checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > 2))) - checkSparkAnswerAndOperator(df.select(exists(col("arr"), x => x > col("threshold")))) - checkSparkAnswerAndOperator( - df.select( - exists(col("arr"), x => x > 0).as("any_positive"), - exists(col("arr"), x => x > 100).as("any_large"))) + checkSparkAnswerAndFallbackReason( + df.select(exists(col("arr"), x => x.isNotNull)), + "element type not supported") } } @@ -1010,7 +993,6 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp val isEven = udf((x: Int) => x % 2 == 0) val df = spark.table(table) - // UDF in lambda body cannot be serialized to native code checkSparkAnswerAndFallbackReason( df.select(exists(col("arr"), x => isEven(x))), "scalaudf is not supported") From 4e10e4b369e3bc4b5702ae740ee4c80e4c0f7b3c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 27 Feb 2026 07:34:41 -0700 Subject: [PATCH 3/3] fix: remove unused variable binding in lambda pattern match --- spark/src/main/scala/org/apache/comet/serde/arrays.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index c95d9d7ef4..240bad577e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -626,7 +626,7 @@ object CometArrayExists extends CometExpressionSerde[ArrayExists] { } expr.function match { - case LambdaFunction(body, Seq(elementVar: NamedLambdaVariable), _) => + case LambdaFunction(body, Seq(_: NamedLambdaVariable), _) => val bodyProto = exprToProto(body, inputs, binding) if (bodyProto.isEmpty) { withInfo(expr, body)