From c1b9f7028381946dba51c31af54fd628070710d9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 12 Mar 2026 06:40:44 -0600 Subject: [PATCH] fix: implement Spark-compatible null handling for arrays_overlap Replace DataFusion's array_has_any (which treats NULL == NULL) with a custom implementation that follows Spark's three-valued logic: - true when arrays share a common non-null element - null when no common non-null elements but either array has nulls - false when no common elements and no nulls Closes #3645 --- .../source/user-guide/latest/compatibility.md | 12 + docs/source/user-guide/latest/expressions.md | 2 +- .../src/array_funcs/arrays_overlap.rs | 411 ++++++++++++++++++ native/spark-expr/src/array_funcs/mod.rs | 2 + native/spark-expr/src/comet_scalar_funcs.rs | 4 + .../expressions/array/arrays_overlap.sql | 4 +- 6 files changed, 432 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/array_funcs/arrays_overlap.rs diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 21695bdf57..a27c7c12aa 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -58,6 +58,18 @@ Expressions that are not 100% Spark-compatible will fall back to Spark by defaul `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting. +## Array Functions + +### ArraysOverlap + +Comet's `arrays_overlap` implementation follows Spark's null handling semantics: when no common non-null elements +exist but either array contains null elements, the result is `null` rather than `false`. This matches Spark's +three-valued logic where `arrays_overlap(array(1, null), array(null, 2))` returns `null`. + +Comet currently uses `ScalarValue`-based comparison for complex element types (structs, nested arrays), which may +have subtle differences from Spark's equality semantics for these types. Primitive and string element types use +native comparisons that match Spark. + ## Regular Expressions Comet uses the Rust regexp crate for evaluating regular expressions, and this has different behavior from Java's diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 0339cd2a3e..794efa29ae 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -245,7 +245,7 @@ Comet supports using the following aggregate functions within window contexts wi | ArrayRemove | Yes | | | ArrayRepeat | No | | | ArrayUnion | No | Behaves differently than spark. Comet sorts the input arrays before performing the union, while Spark preserves the order of the first array and appends unique elements from the second. | -| ArraysOverlap | No | | +| ArraysOverlap | No | See [ArraysOverlap](compatibility.md#arraysoverlap) in the compatibility guide. | | CreateArray | Yes | | | ElementAt | Yes | Input must be an array. Map inputs are not supported. | | Flatten | Yes | | diff --git a/native/spark-expr/src/array_funcs/arrays_overlap.rs b/native/spark-expr/src/array_funcs/arrays_overlap.rs new file mode 100644 index 0000000000..507f3dda51 --- /dev/null +++ b/native/spark-expr/src/array_funcs/arrays_overlap.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowPrimitiveType, DataType}; +use datafusion::common::ScalarValue; +use datafusion::common::{ + cast::{as_large_list_array, as_list_array}, + DataFusionError, Result as DataFusionResult, +}; +use datafusion::logical_expr::ColumnarValue; +use std::collections::HashSet; +use std::hash::Hash; +use std::ops::Range; +use std::sync::Arc; + +/// Spark-compatible arrays_overlap. +/// +/// Semantics (matching Spark): +/// - If either input array is null -> null +/// - If the arrays share a common **non-null** element -> true +/// - If no common non-null elements exist, but either array contains a null element -> null +/// - Otherwise -> false +pub fn spark_arrays_overlap(args: &[ColumnarValue]) -> DataFusionResult { + if args.len() != 2 { + return Err(DataFusionError::Internal( + "arrays_overlap requires exactly 2 arguments".to_string(), + )); + } + + let len = match (&args[0], &args[1]) { + (ColumnarValue::Array(a), _) => a.len(), + (_, ColumnarValue::Array(a)) => a.len(), + (ColumnarValue::Scalar(_), ColumnarValue::Scalar(_)) => 1, + }; + + let left = args[0].clone().into_array(len)?; + let right = args[1].clone().into_array(len)?; + + let result = match left.data_type() { + DataType::List(_) => { + let left_list = as_list_array(&left)?; + let right_list = as_list_array(&right)?; + arrays_overlap_inner(left_list, right_list)? + } + DataType::LargeList(_) => { + let left_list = as_large_list_array(&left)?; + let right_list = as_large_list_array(&right)?; + arrays_overlap_inner(left_list, right_list)? + } + dt => { + return Err(DataFusionError::Internal(format!( + "arrays_overlap expected List or LargeList, got {dt:?}" + ))) + } + }; + + Ok(ColumnarValue::Array(Arc::new(result))) +} + +/// Check overlap for a single row using a pre-populated set and a probe function. +/// This unifies the three-valued null logic across all type specializations. +fn check_overlap( + set: &HashSet, + left_has_null: bool, + right_values_iter: impl Iterator>, +) -> Option { + let mut right_has_null = false; + for item in right_values_iter { + match item { + None => right_has_null = true, + Some(v) => { + if set.contains(&v) { + return Some(true); + } + } + } + } + if left_has_null || right_has_null { + None + } else { + Some(false) + } +} + +fn arrays_overlap_inner( + left: &GenericListArray, + right: &GenericListArray, +) -> DataFusionResult { + let len = left.len(); + let left_values = left.values(); + let right_values = right.values(); + let element_type = left_values.data_type(); + + let mut builder = BooleanArray::builder(len); + + // Dispatch on element type once per batch, then loop over rows inside + // each specialization. This avoids per-row type matching and downcasting. + // Float types are excluded because f32/f64 do not implement Eq + Hash; + // they fall through to the ScalarValue path which uses total ordering. + macro_rules! dispatch_and_loop { + (primitive: $($dt:ident => $at:ty),+ $(,)?) => { + match element_type { + $(DataType::$dt => { + let left_arr = left_values.as_any().downcast_ref::>().unwrap(); + let right_arr = right_values.as_any().downcast_ref::>().unwrap(); + let mut set = HashSet::new(); + for i in 0..len { + if let Some((l_range, r_range)) = row_ranges(left, right, i) { + set.clear(); + let left_has_null = populate_set_primitive(left_arr, l_range, &mut set); + let result = check_overlap(&set, left_has_null, + r_range.map(|idx| if right_arr.is_null(idx) { None } else { Some(right_arr.value(idx)) })); + match result { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + } + })+ + DataType::Utf8 => { + let left_arr = left_values.as_any().downcast_ref::().unwrap(); + let right_arr = right_values.as_any().downcast_ref::().unwrap(); + overlap_loop_string(left, right, left_arr, right_arr, &mut builder, len); + } + DataType::LargeUtf8 => { + let left_arr = left_values.as_any().downcast_ref::().unwrap(); + let right_arr = right_values.as_any().downcast_ref::().unwrap(); + overlap_loop_string(left, right, left_arr, right_arr, &mut builder, len); + } + _ => { + overlap_loop_scalar(left, right, left_values, right_values, &mut builder, len)?; + } + } + }; + } + + dispatch_and_loop! { + primitive: + Int8 => arrow::datatypes::Int8Type, + Int16 => arrow::datatypes::Int16Type, + Int32 => arrow::datatypes::Int32Type, + Int64 => arrow::datatypes::Int64Type, + UInt8 => arrow::datatypes::UInt8Type, + UInt16 => arrow::datatypes::UInt16Type, + UInt32 => arrow::datatypes::UInt32Type, + UInt64 => arrow::datatypes::UInt64Type, + Date32 => arrow::datatypes::Date32Type, + Date64 => arrow::datatypes::Date64Type, + } + + Ok(builder.finish()) +} + +/// Returns the left and right element ranges for row `i`, or None if either list is null. +fn row_ranges( + left: &GenericListArray, + right: &GenericListArray, + i: usize, +) -> Option<(Range, Range)> { + if left.is_null(i) || right.is_null(i) { + return None; + } + let l_start = left.value_offsets()[i].as_usize(); + let l_end = left.value_offsets()[i + 1].as_usize(); + let r_start = right.value_offsets()[i].as_usize(); + let r_end = right.value_offsets()[i + 1].as_usize(); + Some((l_start..l_end, r_start..r_end)) +} + +/// Populate a HashSet from the non-null primitive values in the given range. +/// Returns whether any null was encountered. +fn populate_set_primitive( + arr: &PrimitiveArray, + range: Range, + set: &mut HashSet, +) -> bool +where + T::Native: Eq + Hash, +{ + let mut has_null = false; + for idx in range { + if arr.is_null(idx) { + has_null = true; + } else { + set.insert(arr.value(idx)); + } + } + has_null +} + +/// String-typed overlap loop (handles both Utf8 and LargeUtf8 via GenericStringArray). +fn overlap_loop_string( + left: &GenericListArray, + right: &GenericListArray, + left_arr: &arrow::array::GenericStringArray, + right_arr: &arrow::array::GenericStringArray, + builder: &mut arrow::array::BooleanBuilder, + len: usize, +) { + let mut set: HashSet<&str> = HashSet::new(); + for i in 0..len { + if let Some((l_range, r_range)) = row_ranges(left, right, i) { + set.clear(); + let mut left_has_null = false; + for idx in l_range { + if left_arr.is_null(idx) { + left_has_null = true; + } else { + set.insert(left_arr.value(idx)); + } + } + let result = check_overlap( + &set, + left_has_null, + r_range.map(|idx| { + if right_arr.is_null(idx) { + None + } else { + Some(right_arr.value(idx)) + } + }), + ); + match result { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + } +} + +/// Fallback loop using ScalarValue for types without specialized implementations. +fn overlap_loop_scalar( + left: &GenericListArray, + right: &GenericListArray, + left_values: &ArrayRef, + right_values: &ArrayRef, + builder: &mut arrow::array::BooleanBuilder, + len: usize, +) -> DataFusionResult<()> { + let mut set: HashSet = HashSet::new(); + for i in 0..len { + if let Some((l_range, r_range)) = row_ranges(left, right, i) { + set.clear(); + let mut left_has_null = false; + for idx in l_range { + if left_values.is_null(idx) { + left_has_null = true; + } else { + set.insert(ScalarValue::try_from_array(left_values.as_ref(), idx)?); + } + } + let mut right_has_null = false; + let mut found = false; + for idx in r_range { + if right_values.is_null(idx) { + right_has_null = true; + } else { + let sv = ScalarValue::try_from_array(right_values.as_ref(), idx)?; + if set.contains(&sv) { + found = true; + break; + } + } + } + if found { + builder.append_value(true); + } else if left_has_null || right_has_null { + builder.append_null(); + } else { + builder.append_value(false); + } + } else { + builder.append_null(); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::ListArray; + use arrow::datatypes::Int32Type; + use datafusion::common::Result; + + #[test] + fn test_no_overlap() -> Result<()> { + let left = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + ])]); + let right = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(3), + Some(4), + ])]); + let result = arrays_overlap_inner(&left, &right)?; + assert!(!result.value(0)); + assert!(!result.is_null(0)); + Ok(()) + } + + #[test] + fn test_overlap() -> Result<()> { + let left = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + ])]); + let right = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(3), + Some(4), + Some(5), + ])]); + let result = arrays_overlap_inner(&left, &right)?; + assert!(result.value(0)); + assert!(!result.is_null(0)); + Ok(()) + } + + #[test] + fn test_null_only_overlap_returns_null() -> Result<()> { + // array(1, NULL) vs array(NULL, 2): no common non-null, but both have nulls -> null + let left = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + ])]); + let right = ListArray::from_iter_primitive::(vec![Some(vec![ + None, + Some(2), + ])]); + let result = arrays_overlap_inner(&left, &right)?; + assert!(result.is_null(0)); + Ok(()) + } + + #[test] + fn test_null_array_returns_null() -> Result<()> { + let left = + ListArray::from_iter_primitive::(vec![None::>>]); + let right = ListArray::from_iter_primitive::(vec![Some(vec![Some(1)])]); + let result = arrays_overlap_inner(&left, &right)?; + assert!(result.is_null(0)); + Ok(()) + } + + #[test] + fn test_empty_array() -> Result<()> { + let left = ListArray::from_iter_primitive::(vec![Some(vec![])]); + let right = ListArray::from_iter_primitive::(vec![Some(vec![Some(1)])]); + let result = arrays_overlap_inner(&left, &right)?; + assert!(!result.value(0)); + assert!(!result.is_null(0)); + Ok(()) + } + + #[test] + fn test_full_scenario() -> Result<()> { + // Matches the issue: 5 rows + let left = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), // overlap with right -> true + Some(vec![Some(1), Some(2)]), // no overlap -> false + Some(vec![]), // empty -> false + None, // null array -> null + Some(vec![Some(1), None]), // null but no overlap -> null + ]); + let right = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(3), Some(4), Some(5)]), + Some(vec![Some(3), Some(4)]), + Some(vec![Some(1)]), + Some(vec![Some(1)]), + Some(vec![None, Some(2)]), + ]); + let result = arrays_overlap_inner(&left, &right)?; + + assert!(result.value(0)); + assert!(!result.is_null(0)); + + assert!(!result.value(1)); + assert!(!result.is_null(1)); + + assert!(!result.value(2)); + assert!(!result.is_null(2)); + + assert!(result.is_null(3)); + + assert!(result.is_null(4)); + + Ok(()) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 3ef50a252f..879ad9c12c 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -16,11 +16,13 @@ // under the License. mod array_insert; +mod arrays_overlap; mod get_array_struct_fields; mod list_extract; mod size; pub use array_insert::ArrayInsert; +pub use arrays_overlap::spark_arrays_overlap; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; pub use size::{spark_size, SparkSizeFunc}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index ff75de763b..383245b26c 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -181,6 +181,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) } + "array_has_any" => { + let func = Arc::new(crate::array_funcs::spark_arrays_overlap); + make_comet_scalar_udf!("array_has_any", func, without data_type) + } _ => registry.udf(fun_name).map_err(|e| { DataFusionError::Execution(format!( "Function {fun_name} not found in the registry: {e}", diff --git a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql index 27d28a7402..4b83377457 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/arrays_overlap.sql @@ -24,11 +24,11 @@ CREATE TABLE test_arrays_overlap(a array, b array) USING parquet statement INSERT INTO test_arrays_overlap VALUES (array(1, 2, 3), array(3, 4, 5)), (array(1, 2), array(3, 4)), (array(), array(1)), (NULL, array(1)), (array(1, NULL), array(NULL, 2)) -query ignore(https://github.com/apache/datafusion-comet/issues/3645) +query SELECT arrays_overlap(a, b) FROM test_arrays_overlap -- column + literal -query ignore(https://github.com/apache/datafusion-comet/issues/3645) +query SELECT arrays_overlap(a, array(3, 4, 5)) FROM test_arrays_overlap -- literal + column