Skip to content

Commit c13f167

Browse files
committed
[SPARK-56182][SQL] Allow SPJ reducing identity to other transforms
1 parent f7c2cc4 commit c13f167

2 files changed

Lines changed: 64 additions & 4 deletions

File tree

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -1022,6 +1022,11 @@ case class KeyedShuffleSpec(
           } else {
             left.isSameFunction(right)
           }
+        case (_: AttributeReference, _: TransformExpression) |
+            (_: TransformExpression, _: AttributeReference) =>
+          SQLConf.get.v2BucketingPushPartValuesEnabled &&
+            !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+            SQLConf.get.v2BucketingAllowCompatibleTransforms
         case _ => false
       }
```

```diff
@@ -1042,10 +1047,25 @@ case class KeyedShuffleSpec(
    * @param other other key-grouped shuffle spec
    */
   def reducers(other: KeyedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
-    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
-      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
-      case (_, _) => None
-    }
+    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
+      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
+
+      // Identity transform on this side, arbitrary transform on the other side: create a reducer
+      // that applies the other's transform to the raw identity values. The symmetric case
+      // (TransformExpression, AttributeReference) is handled when the other side calls reducers.
+      // Each partition expression is guaranteed to have exactly one leaf child (asserted in
+      // keyPositions), so `a` lives at position 0 in the row we construct.
+      case (a: AttributeReference, t: TransformExpression) =>
+        val reducerExpr = t.transform { case _: AttributeReference => a }
+        val boundExpr = BindReferences.bindReference(reducerExpr, AttributeSeq(Seq(a)))
+        Some(new Reducer[Any, Any] {
+          override def reduce(v: Any): Any = boundExpr.eval(new GenericInternalRow(Array[Any](v)))
+          override def resultType(): DataType = reducerExpr.dataType
+          override def displayName(): String = reducerExpr.toString
+        })
+
+      case (_, _) => None
+    }
 
     // optimize to not return a value, if none of the partition expressions are reducible
     if (results.forall(p => p.isEmpty)) None else Some(results)
```

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -3528,4 +3528,44 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
       }
     }
   }
+
+  test("SPARK-56182: Reduce identity to other transforms") {
+    val items_partitions = Array(bucket(4, "id"))
+    createTable(items, itemsColumns, items_partitions)
+    sql(s"INSERT INTO testcat.ns.$items VALUES " +
+      s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
+      s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
+      s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
+      s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")
+
+    val purchases_partitions = Array(identity("item_id"))
+    createTable(purchases, purchasesColumns, purchases_partitions)
+    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
+      s"(3, 42.0, cast('2020-01-01' as timestamp)), " +
+      s"(0, 44.0, cast('2020-01-15' as timestamp)), " +
+      s"(1, 46.5, cast('2021-02-08' as timestamp))")
+
+    withSQLConf(
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
+      Seq(
+        s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.item_id = i.id",
+        s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.id = p.item_id"
+      ).foreach { joinString =>
+        val df = sql(
+          s"""
+             |${selectWithMergeJoinHint("i", "p")} id, item_id
+             |FROM $joinString
+             |ORDER BY id, item_id
+             |""".stripMargin)
+
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
+        val groupPartitions = collectGroupPartitions(df.queryExecution.executedPlan)
+        assert(groupPartitions.forall(_.outputPartitioning.numPartitions == 4))
+
+        checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(3, 3)))
+      }
+    }
+  }
 }
```

0 commit comments

Comments (0)