diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 8b11860d0eb87..7804aec263d0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -1022,6 +1022,11 @@ case class KeyedShuffleSpec(
       } else {
         left.isSameFunction(right)
       }
+      case (_: AttributeReference, _: TransformExpression) |
+           (_: TransformExpression, _: AttributeReference) =>
+        SQLConf.get.v2BucketingPushPartValuesEnabled &&
+          !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+          SQLConf.get.v2BucketingAllowCompatibleTransforms
       case _ =>
         false
     }
@@ -1042,10 +1047,25 @@ case class KeyedShuffleSpec(
    *
    * @param other other key-grouped shuffle spec
    */
   def reducers(other: KeyedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
-    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
-      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
-      case (_, _) => None
-    }
+    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
+      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
+
+      // Identity transform on this side, arbitrary transform on the other side: create a reducer
+      // that applies the other's transform to the raw identity values. The symmetric case
+      // (TransformExpression, AttributeReference) is handled when the other side calls reducers.
+      // Each partition expression is guaranteed to have exactly one leaf child (asserted in
+      // keyPositions), so `a` lives at position 0 in the row we construct.
+      case (a: AttributeReference, t: TransformExpression) =>
+        val reducerExpr = t.transform { case _: AttributeReference => a }
+        val boundExpr = BindReferences.bindReference(reducerExpr, AttributeSeq(Seq(a)))
+        Some(new Reducer[Any, Any] {
+          override def reduce(v: Any): Any = boundExpr.eval(new GenericInternalRow(Array[Any](v)))
+          override def resultType(): DataType = reducerExpr.dataType
+          override def displayName(): String = reducerExpr.toString
+        })
+
+      case (_, _) => None
+    }
     // optimize to not return a value, if none of the partition expressions are reducible
     if (results.forall(p => p.isEmpty)) None else Some(results)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 21661c8a90574..688196b47502e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -3528,4 +3528,45 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
       }
     }
   }
+
+  test("SPARK-56182: Reduce identity to other transforms") {
+    val items_partitions = Array(bucket(4, "id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
+      s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
+      s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      s"(3, 42.0, cast('2020-01-01' as timestamp)), " +
+      s"(0, 44.0, cast('2020-01-15' as timestamp)), " +
+      s"(1, 46.5, cast('2021-02-08' as timestamp))")
+
+    // The identity <-> transform compatibility path also requires partially-clustered
+    // distribution to be disabled; pin it explicitly instead of relying on the default.
+    withSQLConf(
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+      SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+      Seq(
+        s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.item_id = i.id",
+        s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.id = p.item_id"
+      ).foreach { joinString =>
+        val df = sql(
+          s"""
+             |${selectWithMergeJoinHint("i", "p")} id, item_id
+             |FROM $joinString
+             |ORDER BY id, item_id
+             |""".stripMargin)
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+        val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan)
+        assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 4))
+
+        checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(3, 3)))
+      }
+    }
+  }
 }