diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 544de6330b167..1fff6c22c5ad8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -517,7 +517,9 @@ case class EnsureRequirements( val (rightReducedDataTypes, rightReducedKeys) = rightReducers.fold( (rightPartitioning.expressionDataTypes, rightPartitioning.partitionKeys) )(rightPartitioning.reduceKeys) - if (leftReducedDataTypes != rightReducedDataTypes) { + val reducedDataTypes = if (leftReducedDataTypes == rightReducedDataTypes) { + leftReducedDataTypes + } else { throw QueryExecutionErrors.storagePartitionJoinIncompatibleReducedTypesError( leftReducers = leftReducers, leftReducedDataTypes = leftReducedDataTypes, @@ -525,9 +527,14 @@ case class EnsureRequirements( rightReducedDataTypes = rightReducedDataTypes) } + val reducedKeyRowOrdering = RowOrdering.createNaturalAscendingOrdering(reducedDataTypes) + val reducedKeyOrdering = + reducedKeyRowOrdering.on((t: InternalRowComparableWrapper) => t.row) + // merge values on both sides - var mergedPartitionKeys = mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys, - joinType, leftPartitioning.keyOrdering).map((_, 1)) + var mergedPartitionKeys = + mergeAndDedupPartitions(leftReducedKeys, rightReducedKeys, joinType, reducedKeyOrdering) + .map((_, 1)) logInfo(log"After merging, there are " + log"${MDC(LogKeys.NUM_PARTITIONS, mergedPartitionKeys.size)} partitions") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b59400bb0c3bd..21661c8a90574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -75,17 +75,23 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with Column.create("dept_id", IntegerType), Column.create("data", StringType)) - def withFunction[T](fn: UnboundFunction)(f: => T): T = { - val id = Identifier.of(Array.empty, fn.name()) - val oldFn = Option.when(catalog.listFunctions(Array.empty).contains(id)) { - val fn = catalog.loadFunction(id) - catalog.dropFunction(id) - fn - } - catalog.createFunction(id, fn) + def withFunction[T](fns: UnboundFunction*)(f: => T): T = { + val fnIds = catalog.listFunctions(Array.empty) + val oldFns = fns.map { fn => + val id = Identifier.of(Array.empty, fn.name()) + val oldFn = Option.when(fnIds.contains(id)) { + val fn = catalog.loadFunction(id) + catalog.dropFunction(id) + fn + } + catalog.createFunction(id, fn) + (id, oldFn) + } try f finally { - catalog.dropFunction(id) - oldFn.foreach(catalog.createFunction(id, _)) + oldFns.foreach { case (id, oldFn) => + catalog.dropFunction(id) + oldFn.foreach(catalog.createFunction(id, _)) + } } } @@ -3441,7 +3447,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } test("SPARK-56046: Reducers with different result types") { - withFunction(UnboundDaysFunctionWithIncompatibleResultTypeReducer) { + withFunction(UnboundDaysFunctionWithToYearsReducerWithDateResult) { val items_partitions = Array(days("arrive_time")) createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO testcat.ns.$items VALUES " + @@ -3478,4 +3484,48 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with } } } + + test("SPARK-56164: Reducers with different result types to original keys") { + withFunction( + UnboundDaysFunctionWithToYearsReducerWithLongResult, + UnboundYearsFunctionWithToYearsReducerWithLongResult) { + val items_partitions = Array(days("arrive_time")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))") + + val purchases_partitions = Array(years("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2021-02-08' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + Seq( + s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.time = i.arrive_time", + s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.arrive_time = p.time" + ).foreach { joinString => + val df = sql( + s""" + |${selectWithMergeJoinHint("i", "p")} id, item_id + |FROM $joinString + |ORDER BY id, item_id + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan) + assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 2)) + + checkAnswer(df, Seq(Row(0, 1), Row(1, 1))) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index a87f77529407e..35102c6893d3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -24,23 +24,34 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -object UnboundYearsFunction extends UnboundFunction { +abstract class UnboundYearsFunctionBase extends UnboundFunction { + protected def isValidType(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _ => false + } + + override def description(): String = name() + override def name(): String = "years" +} + +object UnboundYearsFunction extends UnboundYearsFunctionBase { override def bind(inputType: StructType): BoundFunction = { if (inputType.size == 1 && isValidType(inputType.head.dataType)) YearsFunction else throw new UnsupportedOperationException( "'years' only take date or timestamp as input type") } +} - private def isValidType(dt: DataType): Boolean = dt match { - case DateType | TimestampType => true - case _ => false +object UnboundYearsFunctionWithToYearsReducerWithLongResult extends UnboundYearsFunctionBase { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) { + YearsFunctionWithToYearsReducerWithLongResult + } else throw new UnsupportedOperationException( + "'years' only take date or timestamp as input type") } - - override def description(): String = name() - override def name(): String = "years" } -object YearsFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { +abstract class YearsFunctionBase[O] extends ScalarFunction[Int] with ReducibleFunction[Int, O] { override def inputTypes(): Array[DataType] = Array(TimestampType) override def resultType(): DataType = IntegerType override def name(): String = "years" @@ -49,14 +60,37 @@ object YearsFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int val UTC: ZoneId = ZoneId.of("UTC") val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate - def invoke(ts: Long): Int = { + protected def doInvoke(ts: Long): Long = { val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) } +} +// This `years` function reduces `IntegerType` partition keys to `IntegerType` partition keys when +// partitions are reduced to partitions of a `days` function, which produces `DateType` keys. +object YearsFunction extends YearsFunctionBase[Int] { + def invoke(ts: Long): Int = doInvoke(ts).toInt override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Int] = null } +// This `years` function reduces `IntegerType` partition keys to `LongType` partition keys when +// partitions are reduced to partitions of a `days` function, which produces `DateType` keys. +object YearsFunctionWithToYearsReducerWithLongResult extends YearsFunctionBase[Long] { + def invoke(ts: Long): Int = doInvoke(ts).toInt + override def reducer(otherFunction: ReducibleFunction[_, _]): Reducer[Int, Long] = { + if (otherFunction == DaysFunctionWithToYearsReducerWithLongResult) { + YearsToYearsReducerWithLongResult() + } else { + null + } + } +} + +case class YearsToYearsReducerWithLongResult() extends Reducer[Int, Long] { + override def resultType(): DataType = LongType + override def reduce(days: Int): Long = days.toLong +} + abstract class UnboundDaysFunctionBase extends UnboundFunction { protected def isValidType(dt: DataType): Boolean = dt match { case DateType | TimestampType => true @@ -75,16 +109,25 @@ object UnboundDaysFunction extends UnboundDaysFunctionBase { } } -object UnboundDaysFunctionWithIncompatibleResultTypeReducer extends UnboundDaysFunctionBase { +object UnboundDaysFunctionWithToYearsReducerWithDateResult extends UnboundDaysFunctionBase { + override def bind(inputType: StructType): BoundFunction = { + if (inputType.size == 1 && isValidType(inputType.head.dataType)) { + DaysFunctionWithToYearsReducerWithDateResult + } else throw new UnsupportedOperationException( + "'days' only take date or timestamp as input type") + } +} + +object UnboundDaysFunctionWithToYearsReducerWithLongResult extends UnboundDaysFunctionBase { override def bind(inputType: StructType): BoundFunction = { if (inputType.size == 1 && isValidType(inputType.head.dataType)) { - DaysFunctionWithIncompatibleResultTypeReducer + DaysFunctionWithToYearsReducerWithLongResult } else throw new UnsupportedOperationException( "'days' only take date or timestamp as input type") } } -abstract class DaysFunctionBase extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { +abstract class DaysFunctionBase[O] extends ScalarFunction[Int] with ReducibleFunction[Int, O] { override def inputTypes(): Array[DataType] = Array(TimestampType) override def resultType(): DataType = DateType override def name(): String = "days" @@ -93,7 +136,7 @@ abstract class DaysFunctionBase extends ScalarFunction[Int] with ReducibleFuncti // This `days` function reduces `DateType` partition keys to `IntegerType` partition keys when // partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. -object DaysFunction extends DaysFunctionBase { +object DaysFunction extends DaysFunctionBase[Int] { override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = { if (otherFunc == YearsFunction) { DaysToYearsReducer() @@ -105,32 +148,51 @@ object DaysFunction extends DaysFunctionBase { // This `days` function reduces `DateType` partition keys to `DateType` partition keys when // partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. -object DaysFunctionWithIncompatibleResultTypeReducer extends DaysFunctionBase { +object DaysFunctionWithToYearsReducerWithDateResult extends DaysFunctionBase[Int] { override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Int] = { if (otherFunc == YearsFunction) { - DaysToYearsReducerWithIncompatibleResultType() + DaysToYearsReducerWithDateResult() } else { null } } } -abstract class DaysToYearsReducerBase extends Reducer[Int, Int] { +// This `days` function reduces `DateType` partition keys to `LongType` partition keys when +// partitions are reduced to partitions of a `years` function, which produces `IntegerType` keys. +object DaysFunctionWithToYearsReducerWithLongResult extends DaysFunctionBase[Long] { + override def reducer(otherFunc: ReducibleFunction[_, _]): Reducer[Int, Long] = { + if (otherFunc == YearsFunctionWithToYearsReducerWithLongResult) { + DaysToYearsReducerWithLongResult() + } else { + null + } + } +} + +abstract class DaysToYearsReducerBase[O] extends Reducer[Int, O] { val UTC: ZoneId = ZoneId.of("UTC") val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate - override def reduce(days: Int): Int = { + protected def doReduce(days: Int): Long = { val localDate = EPOCH_LOCAL_DATE.plusDays(days) - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate).toInt + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) } } -case class DaysToYearsReducer() extends DaysToYearsReducerBase { +case class DaysToYearsReducer() extends DaysToYearsReducerBase[Int] { override def resultType(): DataType = IntegerType + override def reduce(days: Int): Int = doReduce(days).toInt } -case class DaysToYearsReducerWithIncompatibleResultType() extends DaysToYearsReducerBase { +case class DaysToYearsReducerWithDateResult() extends DaysToYearsReducerBase[Int] { override def resultType(): DataType = DateType + override def reduce(days: Int): Int = doReduce(days).toInt +} + +case class DaysToYearsReducerWithLongResult() extends DaysToYearsReducerBase[Long] { + override def resultType(): DataType = LongType + override def reduce(days: Int): Long = doReduce(days) } object UnboundBucketFunction extends UnboundFunction {