From 125b0b1acddbf67c173b827f3892f4b6a05a07df Mon Sep 17 00:00:00 2001 From: noroshi Date: Wed, 4 Mar 2026 09:11:52 +0000 Subject: [PATCH] Use datafusion-spark SparkArrayContains for Spark three-valued NULL semantics Replace the CASE WHEN wrapper around array_has with a direct call to datafusion-spark's SparkArrayContains, which handles Spark's three-valued NULL semantics natively: when an element is not found and the array contains NULL elements, the result is NULL instead of false. --- native/core/src/execution/jni_api.rs | 2 ++ .../scala/org/apache/comet/serde/arrays.scala | 35 +++---------------- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1030e30aaf..0c61b3ee9d 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -55,6 +55,7 @@ use datafusion_spark::function::math::expm1::SparkExpm1; use datafusion_spark::function::math::hex::SparkHex; use datafusion_spark::function::math::width_bucket::SparkWidthBucket; use datafusion_spark::function::string::char::CharFunc; +use datafusion_spark::function::array::array_contains::SparkArrayContains; use datafusion_spark::function::string::concat::SparkConcat; use datafusion_spark::function::string::space::SparkSpace; use futures::poll; @@ -387,6 +388,7 @@ fn prepare_datafusion_session_context( // register UDFs from datafusion-spark crate fn register_datafusion_spark_function(session_ctx: &SessionContext) { + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkArrayContains::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default())); diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b7ebb9ba7b..8b096b7278 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -133,36 +133,11 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] { val arrayExprProto = exprToProto(expr.children.head, inputs, binding) val keyExprProto = exprToProto(expr.children(1), inputs, binding) - val arrayContainsScalarExpr = - scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto) - - // Handle NULL array input - return NULL if array is NULL (matching Spark's behavior) - val isNotNullExpr = createUnaryExpr( - expr, - expr.children.head, - inputs, - binding, - (builder, unaryExpr) => builder.setIsNotNull(unaryExpr)) - - val nullLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty) - - if (arrayContainsScalarExpr.isDefined && isNotNullExpr.isDefined && - nullLiteralProto.isDefined) { - val caseWhenExpr = ExprOuterClass.CaseWhen - .newBuilder() - .addWhen(isNotNullExpr.get) - .addThen(arrayContainsScalarExpr.get) - .setElseExpr(nullLiteralProto.get) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setCaseWhen(caseWhenExpr) - .build()) - } else { - withInfo(expr, expr.children: _*) - None - } + // Delegates to datafusion-spark's SparkArrayContains, which handles + // Spark's three-valued NULL semantics natively (no CASE WHEN needed). + val arrayContainsExpr = + scalarFunctionExprToProto("array_contains", arrayExprProto, keyExprProto) + optExprWithInfo(arrayContainsExpr, expr, expr.children: _*) } }