Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,11 @@ case class KeyedShuffleSpec(
} else {
left.isSameFunction(right)
}
case (_: AttributeReference, _: TransformExpression) |
(_: TransformExpression, _: AttributeReference) =>
SQLConf.get.v2BucketingPushPartValuesEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
SQLConf.get.v2BucketingAllowCompatibleTransforms
case _ => false
}

Expand All @@ -1042,10 +1047,25 @@ case class KeyedShuffleSpec(
* @param other other key-grouped shuffle spec
*/
def reducers(other: KeyedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
  // Pair up the partition expressions of both sides positionally and derive, per pair,
  // an optional reducer that maps this side's partition values onto the other side's.
  val results = partitioning.expressions.zip(other.partitioning.expressions).map {
    // Both sides are transforms: delegate to the transform's own reducer derivation.
    case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)

    // Identity transform on this side, arbitrary transform on the other side: create a reducer
    // that applies the other's transform to the raw identity values. The symmetric case
    // (TransformExpression, AttributeReference) is handled when the other side calls reducers.
    // Each partition expression is guaranteed to have exactly one leaf child (asserted in
    // keyPositions), so `a` lives at position 0 in the row we construct.
    case (a: AttributeReference, t: TransformExpression) =>
      val reducerExpr = t.transform { case _: AttributeReference => a }
      val boundExpr = BindReferences.bindReference(reducerExpr, AttributeSeq(Seq(a)))
      Some(new Reducer[Any, Any] {
        // Wrap the single raw value in a one-column row so the bound expression can read it.
        override def reduce(v: Any): Any = boundExpr.eval(new GenericInternalRow(Array[Any](v)))
        override def resultType(): DataType = reducerExpr.dataType
        override def displayName(): String = reducerExpr.toString
      })

    // Any other pairing (e.g. attribute vs attribute) is not reducible.
    case (_, _) => None
  }

  // optimize to not return a value, if none of the partition expressions are reducible
  if (results.forall(p => p.isEmpty)) None else Some(results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3528,4 +3528,44 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase with
}
}
}

test("SPARK-56182: Reduce identity to other transforms") {
  // Items side: bucketed into 4 buckets on `id`.
  createTable(items, itemsColumns, Array(bucket(4, "id")))
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " +
    s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    s"(2, 'bb', 41.0, cast('2021-01-03' as timestamp)), " +
    s"(3, 'bb', 42.0, cast('2021-01-04' as timestamp))")

  // Purchases side: identity-partitioned on `item_id`, so its raw partition values
  // must be reduced through the other side's bucket transform for the join to align.
  createTable(purchases, purchasesColumns, Array(identity("item_id")))
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    s"(3, 42.0, cast('2020-01-01' as timestamp)), " +
    s"(0, 44.0, cast('2020-01-15' as timestamp)), " +
    s"(1, 46.5, cast('2021-02-08' as timestamp))")

  withSQLConf(
    SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
    SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") {
    // Exercise both join orders: identity side on the left and on the right.
    val joinClauses = Seq(
      s"testcat.ns.$items i JOIN testcat.ns.$purchases p ON p.item_id = i.id",
      s"testcat.ns.$purchases p JOIN testcat.ns.$items i ON i.id = p.item_id")

    for (joinClause <- joinClauses) {
      val df = sql(
        s"""
           |${selectWithMergeJoinHint("i", "p")} id, item_id
           |FROM $joinClause
           |ORDER BY id, item_id
           |""".stripMargin)

      val plan = df.queryExecution.executedPlan
      // Reducing identity to the bucket transform should make both sides compatible,
      // so no exchange is inserted on either side.
      assert(collectShuffles(plan).isEmpty, "should not add shuffle for both sides of the join")
      // All grouped partitions follow the bucketed side's partition count.
      assert(collectGroupPartitions(plan).forall(_.outputPartitioning.numPartitions == 4))

      // Only ids 0, 1 and 3 have matching purchases.
      checkAnswer(df, Seq(Row(0, 0), Row(1, 1), Row(3, 3)))
    }
  }
}
}