From 84052f2fe80348611e91c458de12589324e0c9b4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 23 Mar 2026 20:41:03 +0100 Subject: [PATCH 1/5] [SPARK-56164][SQL] Fix SPJ merged key ordering --- .../exchange/EnsureRequirements.scala | 9 +- .../KeyGroupedPartitioningSuite.scala | 72 +++++++++++--- .../functions/transformFunctions.scala | 98 +++++++++++++++---- 3 files changed, 147 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 544de6330b167..1c4ab70c0ea7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -525,9 +525,14 @@ case class EnsureRequirements( rightReducedDataTypes = rightReducedDataTypes) } + val reducedKeyRowOrdering = RowOrdering.createNaturalAscendingOrdering(leftReducedDataTypes) + val reducedKeyOrdering = + reducedKeyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) + // merge values on both sides - var mergedPartitionKeys = mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys, - joinType, leftPartitioning.keyOrdering).map((_, 1)) + var mergedPartitionKeys = + mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys, joinType, reducedKeyOrdering) + .map((_, 1)) logInfo(log"After merging, there are " + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b59400bb0c3bd..3f88dbc2e1357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -75,17 +75,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Column.create("dept_id", IntegerType), Column.create("data", StringType)) - def withFunction[T](fn: UnboundFunction)(f: => T): T = { - val id = Identifier.of(Array.empty, fn.name()) - val oldFn = Option.when(catalog.listFunctions(Array.empty).contains(id)) { - val fn = catalog.loadFunction(id) - catalog.dropFunction(id) - fn - } - catalog.createFunction(id, fn) + def withFunctions[T](fns: UnboundFunction*)(f: => T): T = { + val fnIds = catalog.listFunctions(Array.empty) + val oldFns = fns.map { fn => + val id = Identifier.of(Array.empty, fn.name()) + val oldFn = Option.when(fnIds.contains(id)) { + val fn = catalog.loadFunction(id) + catalog.dropFunction(id) + fn + } + catalog.createFunction(id, fn) + (id, oldFn) + } try f finally { - catalog.dropFunction(id) - oldFn.foreach(catalog.createFunction(id, _)) + oldFns.foreach { case (id, oldFn ) => + catalog.dropFunction(id) + oldFn.foreach(catalog.createFunction(id, _)) + } } } @@ -3441,7 +3447,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-56046: Reducers with different result types") { - withFunction(UnboundDaysFunctionWithIncompatibleResultTypeReducer) { + withFunctions(UnboundDaysFunctionWithToYearsReducerWithDateResult) { val items_partitions = Array(days("arrive_time")) createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + @@ -3478,4 +3484,48 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + test("SPARK-56164: Reducers with different result types to original keys") { + withFunctions( + UnboundDaysFunctionWithToYearsReducerWithLongResult, + UnboundYearsFunctionWithToYearsReducerWithLongResult) { + val items_partitions = Array(days("arrive_time")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))") + + val purchases_partitions = Array(years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2021-02-08' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + Seq( + s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time", + s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time" + ).foreach { joinString => + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} id, item_id + |FROM $joinString + |ORDER BY id, item_id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) + + checkAnswer(df, Seq(Row(0, 1), Row(1, 1))) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index a87f77529407e..43c0e06d931f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -24,23 +24,34 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -object UnboundYearsFunction extends UnboundFunction { +abstract class UnboundYearsFunctionBase extends UnboundFunction { + protected def isValidType(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _ => false + } + + override def description(): String = name() + override def name(): String = "years" +} + +object UnboundYearsFunction extends UnboundYearsFunctionBase { override def bind(inputType: StructType): BoundFunction = { if (inputType.size == 1 && isValidType(inputType.head.dataType)) YearsFunction else throw new UnsupportedOperationException( "'years' only take date or timestamp as input type") } +} - private def isValidType(dt: DataType): Boolean = dt match { - case DateType | TimestampType => true - case _ => false +object UnboundYearsFunctionWithToYearsReducerWithLongResult extends UnboundYearsFunctionBase { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) { + YearsFunctionWithToYearsReducerWithLongResult + } else throw new UnsupportedOperationException( + "'years' only take date or timestamp as input type") } - - override def description(): String = name() - override def name(): String = "years" } -object YearsFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { +abstract class YearsFunctionBase[O] extends ScalarFunction[Int] with ReducibleFunction[Int, O] { override def inputTypes(): Array[DataType] = Array(TimestampType) override def resultType(): DataType = IntegerType override def name(): String = "years" @@ -53,10 +64,31 @@ object YearsFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt } +} +// This `years` function reduces `IntegerType` partition keys to `IntegerType` partition keys when +// partitions are reduced to partitions of a `days` function, which produces `DateType` keys. +object YearsFunction extends YearsFunctionBase[Int] { override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Int] = null } +// This `years` function reduces `IntegerType` partition keys to `LongType` partition keys when +// partitions are reduced to partitions of a `days` function, which produces `DateType` keys. +object YearsFunctionWithToYearsReducerWithLongResult extends YearsFunctionBase[Long] { + override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Long] = { + if (otherFunction == DaysFunctionWithToYearsReducerWithLongResult) { + YearsToYearsReducerWithLogResult() + } else { + null + } + } +} + +case class YearsToYearsReducerWithLogResult() extends Reducer[Int, Long] { + override def resultType(): DataType = LongType + override def reduce(days: Int): Long = days.toLong +} + abstract class UnboundDaysFunctionBase extends UnboundFunction { protected def isValidType(dt: DataType): Boolean = dt match { case DateType | TimestampType => true @@ -75,16 +107,25 @@ object UnboundDaysFunction extends UnboundDaysFunctionBase { } } -object UnboundDaysFunctionWithIncompatibleResultTypeReducer extends UnboundDaysFunctionBase { +object UnboundDaysFunctionWithToYearsReducerWithDateResult extends UnboundDaysFunctionBase { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) { + DaysFunctionWithToYearsReducerWithDateResult + } else throw new UnsupportedOperationException( + "'days' only take date or timestamp as input type") + } +} + +object UnboundDaysFunctionWithToYearsReducerWithLongResult extends UnboundDaysFunctionBase { override def bind(inputType: StructType): BoundFunction = { if (inputType.size == 1 && isValidType(inputType.head.dataType)) { - DaysFunctionWithIncompatibleResultTypeReducer + DaysFunctionWithToYearsReducerWithLongResult } else throw new UnsupportedOperationException( "'days' only take date or timestamp as input type") } } -abstract class DaysFunctionBase extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { +abstract class DaysFunctionBase[O] extends ScalarFunction[Int] with ReducibleFunction[Int, O] { override def inputTypes(): Array[DataType] = Array(TimestampType) override def resultType(): DataType = DateType override def name(): String = "days" @@ -93,7 +134,7 @@ abstract class DaysFunctionBase extends ScalarFunction[Int] with ReducibleFuncti // This `days` function reduces `DateType` partition keys to `IntegerType` partition keys when // partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. -object DaysFunction extends DaysFunctionBase { +object DaysFunction extends DaysFunctionBase[Int] { override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = { if (otherFunc == YearsFunction) { DaysToYearsReducer() @@ -105,32 +146,51 @@ object DaysFunction extends DaysFunctionBase { // This `days` function reduces `DateType` partition keys to `DateType` partition keys when // partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. -object DaysFunctionWithIncompatibleResultTypeReducer extends DaysFunctionBase { +object DaysFunctionWithToYearsReducerWithDateResult extends DaysFunctionBase[Int] { override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = { if (otherFunc == YearsFunction) { - DaysToYearsReducerWithIncompatibleResultType() + DaysToYearsReducerWithDateResult() } else { null } } } -abstract class DaysToYearsReducerBase extends Reducer[Int, Int] { +// This `days` function reduces `DateType` partition keys to `LongType` partition keys when +// partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. +object DaysFunctionWithToYearsReducerWithLongResult extends DaysFunctionBase[Long] { + override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Long] = { + if (otherFunc == YearsFunctionWithToYearsReducerWithLongResult) { + DaysToYearsReducerWithLongResult() + } else { + null + } + } +} + +abstract class DaysToYearsReducerBase[O] extends Reducer[Int, O] { val UTC: ZoneId = ZoneId.of("UTC") val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate - override def reduce(days: Int): Int = { + protected def doReduce(days: Int): Long = { val localDate = EPOCH_LOCAL_DATE.plusDays(days) - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) } } -case class DaysToYearsReducer() extends DaysToYearsReducerBase { +case class DaysToYearsReducer() extends DaysToYearsReducerBase[Int] { override def resultType(): DataType = IntegerType + override def reduce(days: Int): Int = doReduce(days).toInt } -case class DaysToYearsReducerWithIncompatibleResultType() extends DaysToYearsReducerBase { +case class DaysToYearsReducerWithDateResult() extends DaysToYearsReducerBase[Int] { override def resultType(): DataType = DateType + override def reduce(days: Int): Int = doReduce(days).toInt +} + +case class DaysToYearsReducerWithLongResult() extends DaysToYearsReducerBase[Long] { + override def resultType(): DataType = LongType + override def reduce(days: Int): Long = doReduce(days) } object UnboundBucketFunction extends UnboundFunction { From 9e3085fbf64a63eab7ba1c05715433581f12bbf7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 23 Mar 2026 21:22:36 +0100 Subject: [PATCH 2/5] use withFunction --- .../spark/sql/connector/KeyGroupedPartitioningSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 3f88dbc2e1357..643c9a22d1142 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -75,7 +75,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Column.create("dept_id", IntegerType), Column.create("data", StringType)) - def withFunctions[T](fns: UnboundFunction*)(f: => T): T = { + def withFunction[T](fns: UnboundFunction*)(f: => T): T = { val fnIds = catalog.listFunctions(Array.empty) val oldFns = fns.map { fn => val id = Identifier.of(Array.empty, fn.name()) @@ -3447,7 +3447,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-56046: Reducers with different result types") { - withFunctions(UnboundDaysFunctionWithToYearsReducerWithDateResult) { + withFunction(UnboundDaysFunctionWithToYearsReducerWithDateResult) { val items_partitions = Array(days("arrive_time")) createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + @@ -3486,7 +3486,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-56164: Reducers with different result types to original keys") { - withFunctions( + withFunction( UnboundDaysFunctionWithToYearsReducerWithLongResult, UnboundYearsFunctionWithToYearsReducerWithLongResult) { val items_partitions = Array(days("arrive_time")) From 31fb4efe12d995cf5864534cd389ea999a04cf80 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Mar 2026 09:42:04 +0100 Subject: [PATCH 3/5] address review findings --- .../spark/sql/execution/exchange/EnsureRequirements.scala | 6 ++++-- .../spark/sql/connector/KeyGroupedPartitioningSuite.scala | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1c4ab70c0ea7c..1fff6c22c5ad8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -517,7 +517,9 @@ case class EnsureRequirements( val (rightReducedDataTypes, rightReducedKeys) = rightReducers.fold( (rightPartitioning.expressionDataTypes, rightPartitioning.partitionKeys) )(rightPartitioning.reduceKeys) - if (leftReducedDataTypes != rightReducedDataTypes) { + val reducedDataTypes = if (leftReducedDataTypes == rightReducedDataTypes) { + leftReducedDataTypes + } else { throw QueryExecutionErrors.storagePartitionJoinIncompatibleReducedTypesError( leftReducers = leftReducers, leftReducedDataTypes = leftReducedDataTypes, @@ -525,7 +527,7 @@ case class EnsureRequirements( rightReducedDataTypes = rightReducedDataTypes) } - val reducedKeyRowOrdering = RowOrdering.createNaturalAscendingOrdering(leftReducedDataTypes) + val reducedKeyRowOrdering = RowOrdering.createNaturalAscendingOrdering(reducedDataTypes) val reducedKeyOrdering = reducedKeyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 643c9a22d1142..21661c8a90574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -88,7 +88,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with (id, oldFn) } try f finally { - oldFns.foreach { case (id, oldFn ) => + oldFns.foreach { case (id, oldFn) => catalog.dropFunction(id) oldFn.foreach(catalog.createFunction(id, _)) } From 5daa4a70617f70f5437236deb55f9bd8c9d35020 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Mar 2026 09:43:02 +0100 Subject: [PATCH 4/5] fix test failures due to missing invoke --- .../connector/catalog/functions/transformFunctions.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 43c0e06d931f4..b274699293436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -60,21 +60,23 @@ abstract class YearsFunctionBase[O] extends ScalarFunction[Int] with ReducibleFu val UTC: ZoneId = ZoneId.of("UTC") val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate - def invoke(ts: Long): Int = { + protected def doInvoke(ts: Long): Long = { val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) } } // This `years` function reduces `IntegerType` partition keys to `IntegerType` partition keys when // partitions are reduced to partitions of a `days` function, which produces `DateType` keys. object YearsFunction extends YearsFunctionBase[Int] { + def invoke(ts: Long): Int = doInvoke(ts).toInt override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Int] = null } // This `years` function reduces `IntegerType` partition keys to `LongType` partition keys when // partitions are reduced to partitions of a `days` function, which produces `DateType` keys. object YearsFunctionWithToYearsReducerWithLongResult extends YearsFunctionBase[Long] { + def invoke(ts: Long): Int = doInvoke(ts).toInt override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Long] = { if (otherFunction == DaysFunctionWithToYearsReducerWithLongResult) { YearsToYearsReducerWithLogResult() From ee98c7e25689137283258161925892a5d2b930f2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 24 Mar 2026 11:16:15 +0100 Subject: [PATCH 5/5] fix name --- .../sql/connector/catalog/functions/transformFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index b274699293436..35102c6893d3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -79,14 +79,14 @@ object YearsFunctionWithToYearsReducerWithLongResult extends YearsFunctionBase[L def invoke(ts: Long): Int = doInvoke(ts).toInt override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Long] = { if (otherFunction == DaysFunctionWithToYearsReducerWithLongResult) { - YearsToYearsReducerWithLogResult() + YearsToYearsReducerWithLongResult() } else { null } } } -case class YearsToYearsReducerWithLogResult() extends Reducer[Int, Long] { +case class YearsToYearsReducerWithLongResult() extends Reducer[Int, Long] { override def resultType(): DataType = LongType override def reduce(days: Int): Long = days.toLong }