Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,129 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
val keyExprProto = exprToProto(expr.children(1), inputs, binding)

val arrayContainsScalarExpr =
// Check if array is null - if so, return null
val isArrayNotNullExpr = createUnaryExpr(
expr,
expr.children.head,
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))

// Check if search value is null - if so, return null
val isKeyNotNullExpr = createUnaryExpr(
expr,
expr.children(1),
inputs,
binding,
(builder, unaryExpr) => builder.setIsNotNull(unaryExpr))

// Check if value exists in array
val arrayHasValueExpr =
scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)
optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)

// Detect null elements without `array_has(array, null)`: DataFusion's array_has returns
// null when the needle is null, so it cannot be used as a boolean predicate in CaseWhen.
// Removing nulls and comparing lengths matches Spark's indeterminate case.
val nullKeyLiteralProto = exprToProto(Literal(null, expr.children(1).dataType), Seq.empty)
val arrayType = expr.children.head.dataType.asInstanceOf[ArrayType]
val arrayWithoutNullsExpr = scalarFunctionExprToProtoWithReturnType(
"array_remove_all",
arrayType,
false,
arrayExprProto,
nullKeyLiteralProto)
val arraySizeExpr = scalarFunctionExprToProto("size", arrayExprProto)
val arraySizeWithoutNullsExpr = arrayWithoutNullsExpr.flatMap { removed =>
scalarFunctionExprToProto("size", Some(removed))
}
val arrayHasNullElementExpr = for {
s0 <- arraySizeExpr
s1 <- arraySizeWithoutNullsExpr
} yield {
val neq = ExprOuterClass.BinaryExpr
.newBuilder()
.setLeft(s0)
.setRight(s1)
.build()
ExprOuterClass.Expr.newBuilder().setNeq(neq).build()
}

// Build the three-valued logic:
// 1. If array is null -> return null
// 2. If key is null -> return null
// 3. If array_has(array, key) is true -> return true
// 4. If array_has(array, key) is false AND the array still contains null elements
// -> return null (indeterminate)
// 5. If array_has(array, key) is false AND there are no null elements -> return false
if (isArrayNotNullExpr.isDefined && isKeyNotNullExpr.isDefined &&
arrayHasValueExpr.isDefined && arrayHasNullElementExpr.isDefined &&
nullKeyLiteralProto.isDefined) {
// Create boolean literals
val trueLiteralProto = exprToProto(Literal(true, BooleanType), Seq.empty)
val falseLiteralProto = exprToProto(Literal(false, BooleanType), Seq.empty)
val nullBooleanLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)

if (trueLiteralProto.isDefined && falseLiteralProto.isDefined &&
nullBooleanLiteralProto.isDefined) {
// If array_has(array, key) is false, check if array has nulls
// If size(array) != size(array_remove_all(array, null)) -> return null, else false
val whenNotFoundCheckNulls = ExprOuterClass.CaseWhen
.newBuilder()
.addWhen(arrayHasNullElementExpr.get) // if array has null elements
.addThen(nullBooleanLiteralProto.get) // return null (indeterminate)
.setElseExpr(falseLiteralProto.get) // else return false
.build()

// If array_has(array, key) is true, return true, else check null case
val whenValueFound = ExprOuterClass.CaseWhen
.newBuilder()
.addWhen(arrayHasValueExpr.get) // if value found
.addThen(trueLiteralProto.get) // return true
.setElseExpr(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(whenNotFoundCheckNulls)
.build()
) // else check null case
.build()

// Check if key is null -> return null, else use the logic above
val whenKeyNotNull = ExprOuterClass.CaseWhen
.newBuilder()
.addWhen(isKeyNotNullExpr.get) // if key is not null
.addThen(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(whenValueFound)
.build())
.setElseExpr(nullBooleanLiteralProto.get) // key is null -> return null
.build()

// Outer case: if array is null, return null, else use the logic above
val outerCaseWhen = ExprOuterClass.CaseWhen
.newBuilder()
.addWhen(isArrayNotNullExpr.get) // if array is not null
.addThen(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(whenKeyNotNull)
.build())
.setElseExpr(nullBooleanLiteralProto.get) // array is null -> return null
.build()

Some(
ExprOuterClass.Expr
.newBuilder()
.setCaseWhen(outerCaseWhen)
.build())
} else {
withInfo(expr, expr.children: _*)
None
}
} else {
withInfo(expr, expr.children: _*)
None
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,48 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}

test("array_contains - three-valued null logic") {
// Test Spark's three-valued logic for array_contains:
// 1. Returns true if value is found
// 2. Returns false if no match found AND no null elements exist
// 3. Returns null if no match found BUT null elements exist (indeterminate)
// 4. Returns null if search value is null
withTempDir { dir =>
withTempView("t1") {
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 100)
spark.read.parquet(path.toString).createOrReplaceTempView("t1")

// Disable constant folding to ensure Comet implementation is exercised
withSQLConf(
SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
"org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
// Test case 1: value found -> returns true
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 2) FROM t1"))

// Test case 2: no match, no nulls -> returns false
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 5) FROM t1"))

// Test case 3: no match, but null exists -> returns null (indeterminate)
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 2) FROM t1"))

// Test case 4: match found even with nulls -> returns true
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 1) FROM t1"))

// Test case 5: search value is null -> returns null
checkSparkAnswerAndOperator(
sql("SELECT array_contains(array(1, 2, 3), cast(null as int)) FROM t1"))

// Test case 6: array with nulls, searching for existing value -> returns true
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 3) FROM t1"))

// Test case 7: empty array -> returns false
checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t1"))
}
}
}
}

test("array_contains - test all types (convert from Parquet)") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
Expand Down