From 5f66242c0b22142cb9904baa9ca26a1caf35a2ba Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 27 Dec 2025 15:10:23 +0530 Subject: [PATCH 1/5] Allow Spark partial / Comet final for compatible aggregates --- .../serde/CometAggregateExpressionSerde.scala | 14 ++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 17 ++++++++++ .../org/apache/comet/serde/aggregates.scala | 25 ++++++++++++++ .../apache/spark/sql/comet/operators.scala | 11 +++--- .../comet/rules/CometExecRuleSuite.scala | 34 +++++++++++++++++-- 5 files changed, 95 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala index 0a5a2770b4..fc8f776ca7 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala @@ -49,6 +49,20 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] { */ def getSupportLevel(expr: T): SupportLevel = Compatible(None) + /** + * Indicates whether this aggregate function supports "Spark partial / Comet final" mixed + * execution. This requires the intermediate buffer format to be compatible between Spark and + * Comet. + * + * Only aggregates with simple, compatible intermediate buffers should return true. Aggregates + * with complex buffers or those with known incompatibilities (e.g., decimal overflow handling + * differences) should return false. + * + * @return + * true if the aggregate can safely run with Spark partial and Comet final, false otherwise + */ + def supportsSparkPartialCometFinal: Boolean = false + /** * Convert a Spark expression into a protocol buffer representation that can be passed into * native code. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..e37f6709e6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -269,6 +269,23 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[VariancePop] -> CometVariancePop, classOf[VarianceSamp] -> CometVarianceSamp) + /** + * Checks if the given aggregate function supports "Spark partial / Comet final" mixed + * execution. This is used to determine if Comet can process a final aggregate even when the + * partial aggregate was performed by Spark. + * + * @param fn + * The aggregate function to check + * @return + * true if the aggregate supports mixed execution, false otherwise + */ + def aggSupportsMixedExecution(fn: AggregateFunction): Boolean = { + aggrSerdeMap.get(fn.getClass) match { + case Some(handler) => handler.supportsSparkPartialCometFinal + case None => false + } + } + // A unique id for each expression. ~used to look up QueryContext during error creation. private val exprIdCounter = new AtomicLong(0) diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 1485589b46..4f9d5a3c8a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -34,6 +34,9 @@ import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { + // Min has a simple intermediate buffer (single value) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Min, @@ -81,6 +84,9 @@ object CometMin extends CometAggregateExpressionSerde[Min] { object CometMax extends CometAggregateExpressionSerde[Max] { + // Max has a simple intermediate buffer (single value) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Max, @@ -127,6 +133,10 @@ object CometMax extends CometAggregateExpressionSerde[Max] { } object CometCount extends CometAggregateExpressionSerde[Count] { + + // Count has a simple intermediate buffer (single Long) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Count, @@ -306,6 +316,11 @@ object CometLast extends CometAggregateExpressionSerde[Last] { } object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { + + // BitAnd has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitAnd: BitAndAgg, @@ -340,6 +355,11 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { } object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { + + // BitOr has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitOr: BitOrAgg, @@ -374,6 +394,11 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { } object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] { + + // BitXor has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitXor: BitXorAgg, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..f757ad51be 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -55,7 +55,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, with import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} -import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType} +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, aggSupportsMixedExecution, exprToProto, supportedSortType} import org.apache.comet.serde.operator.CometSink /** @@ -1317,11 +1317,14 @@ trait CometBaseAggregate { val modes = aggregate.aggregateExpressions.map(_.mode).distinct // In distinct aggregates there can be a combination of modes val multiMode = modes.size > 1 - // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation. + // For a final mode HashAggregate, check if there is Comet partial aggregation. + // If not, we can still proceed if all aggregates support mixed execution + // (Spark partial / Comet final). See https://github.com/apache/datafusion-comet/issues/2894 val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty + val allSupportMixedExecution = aggregate.aggregateExpressions.forall(expr => + aggSupportsMixedExecution(expr.aggregateFunction)) - if (multiMode || sparkFinalMode) { + if (multiMode || (sparkFinalMode && !allSupportMixedExecution)) { return None } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index cf6f8918f4..a343f7fb93 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -157,7 +157,9 @@ class CometExecRuleSuite extends CometTestBase { } } - test("CometExecRule should not allow Spark partial and Comet final hash aggregate") { + test("CometExecRule should not allow Spark partial and Comet final for unsafe aggregates") { + // https://github.com/apache/datafusion-comet/issues/2894 + // SUM is not safe for mixed execution due to potential overflow handling differences withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") @@ -173,7 +175,7 @@ class CometExecRuleSuite extends CometTestBase { CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { val transformedPlan = applyCometExecRule(sparkPlan) - // if the partial aggregate cannot be converted to Comet, then neither should be + // SUM is not safe for mixed execution, so both partial and final should fall back assert( countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount) assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) @@ -181,6 +183,34 @@ class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should allow Spark partial and Comet final for safe aggregates") { + // https://github.com/apache/datafusion-comet/issues/2894 + // MIN, MAX, COUNT are safe for mixed execution (simple intermediate buffer) + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = + createSparkPlan( + spark, + "SELECT MIN(id), MAX(id), COUNT(*) FROM test_data GROUP BY (id % 3)") + + // Count original Spark operators (should be 2: partial + final) + val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec]) + assert(originalHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // MIN, MAX, COUNT support mixed execution, so final should run in Comet + // Partial stays in Spark (1 HashAggregateExec), final runs in Comet (1 CometHashAggregateExec) + assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) + } + } + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") From f9968bef0126fdf49d8d1c9a7dcb4934ff0a71f3 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Fri, 9 Jan 2026 18:47:16 +0530 Subject: [PATCH 2/5] Add unit tests for aggSupportsMixedExecution including bitwise aggregates --- .../comet/rules/CometExecRuleSuite.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index a343f7fb93..4c3980165c 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,6 +22,8 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Max, Min, Sum} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ @@ -29,8 +31,10 @@ import org.apache.spark.sql.execution.adaptive.QueryStageExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} +import org.apache.spark.sql.types.IntegerType import org.apache.comet.CometConf +import org.apache.comet.serde.QueryPlanSerde import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} /** @@ -211,6 +215,65 @@ class CometExecRuleSuite extends CometTestBase { } } + test("QueryPlanSerde.aggSupportsMixedExecution should return true for safe aggregates") { + // Test aggregates that support Spark partial / Comet final execution + val testAttr = AttributeReference("id", IntegerType, nullable = true)() + assert( + QueryPlanSerde.aggSupportsMixedExecution(Min(testAttr)), + "Min should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(Max(testAttr)), + "Max should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(Count(testAttr)), + "Count should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitAndAgg(testAttr)), + "BitAndAgg should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitOrAgg(testAttr)), + "BitOrAgg should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitXorAgg(testAttr)), + "BitXorAgg should support mixed execution") + } + + test("QueryPlanSerde.aggSupportsMixedExecution should return true for bitwise aggregates") { + // Test bitwise aggregates that support Spark partial / Comet final execution + val testAttr = AttributeReference("id", IntegerType, nullable = true)() + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitAndAgg(testAttr)), + "BitAndAgg should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitOrAgg(testAttr)), + "BitOrAgg should support mixed execution") + assert( + QueryPlanSerde.aggSupportsMixedExecution(BitXorAgg(testAttr)), + "BitXorAgg should support mixed execution") + } + + test("QueryPlanSerde.aggSupportsMixedExecution should return false for unsafe aggregates") { + // Test aggregates that don't support Spark partial / Comet final execution + val testAttr = AttributeReference("id", IntegerType, nullable = true)() + assert( + !QueryPlanSerde.aggSupportsMixedExecution(Sum(testAttr)), + "Sum should not support mixed execution") + assert( + !QueryPlanSerde.aggSupportsMixedExecution(Average(testAttr)), + "Average should not support mixed execution") + } + + test( + "QueryPlanSerde.aggSupportsMixedExecution should return false for aggregates with default implementation") { + // Test with an aggregate function that's in the map but doesn't override supportsSparkPartialCometFinal + val testAttr = AttributeReference("id", IntegerType, nullable = true)() + val firstAgg = new org.apache.spark.sql.catalyst.expressions.aggregate.First(testAttr, true) + // First is in the map but doesn't override supportsSparkPartialCometFinal, so should return false + assert( + !QueryPlanSerde.aggSupportsMixedExecution(firstAgg), + "First should not support mixed execution (default false)") + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") From 24db310f9a2e5050166fa483b8d79fb765e455a3 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Thu, 22 Jan 2026 21:17:55 +0530 Subject: [PATCH 3/5] Fix build errors: Remove unused imports and fix feature flag conflicts --- .../test/scala/org/apache/comet/rules/CometExecRuleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 4c3980165c..1c6283a5fe 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,7 +22,7 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, BitAndAgg, BitOrAgg, BitXorAgg, Count, Max, Min, Sum} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec From 141ba57f29165ff279ec25137b8d4932c1f67359 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Tue, 17 Mar 2026 00:38:24 +0530 Subject: [PATCH 4/5] Add bitwise aggregate mixed execution test --- .../comet/exec/CometAggregateSuite.scala | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 9426d1c848..a5afc557bf 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1292,6 +1292,34 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("bitwise aggregates allow Spark partial and Comet final") { + withSQLConf( + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "jvm", + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val table = "bitwise_mixed" + withTable(table) { + sql( + s"create table $table(col1 long, col2 int, col3 short, col4 byte, grp int) using parquet") + sql( + s"insert into $table values" + + "(4, 1, 1, 3, 0), (4, 1, 1, 3, 0), (3, 3, 1, 4, 1)," + + " (2, 4, 2, 5, 1), (1, 3, 2, 6, 0)") + + // Partial aggregate stays in Spark, final aggregate runs in Comet. + checkSparkAnswerAndNumOfAggregates( + s"SELECT grp, BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," + + " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," + + " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," + + " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4)" + + s" FROM $table GROUP BY grp", + 1) + } + } + } + def setupAndTestAggregates( table: String, data: Seq[(Any, Any, Any)], From 4eaef6b08665cba9664450d2053b687ba6cee7b0 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Wed, 18 Mar 2026 10:59:10 +0530 Subject: [PATCH 5/5] minor change --- .../test/scala/org/apache/comet/exec/CometAggregateSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index a5afc557bf..6033725a2c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1310,7 +1310,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Partial aggregate stays in Spark, final aggregate runs in Comet. checkSparkAnswerAndNumOfAggregates( - s"SELECT grp, BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," + + "SELECT grp, BIT_AND(col1), BIT_OR(col1), BIT_XOR(col1)," + " BIT_AND(col2), BIT_OR(col2), BIT_XOR(col2)," + " BIT_AND(col3), BIT_OR(col3), BIT_XOR(col3)," + " BIT_AND(col4), BIT_OR(col4), BIT_XOR(col4)" +