diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 3c210ca7d985b..ddfa002c54dca 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -8331,6 +8331,12 @@
     ],
     "sqlState" : "42KDF"
   },
+  "ZIP_PLANS_NOT_MERGEABLE" : {
+    "message" : [
+      "The two DataFrames in zip() cannot be merged because they do not derive from the same base plan through Project operations."
+    ],
+    "sqlState" : "42K03"
+  },
   "_LEGACY_ERROR_TEMP_0001" : {
     "message" : [
       "Invalid InsertIntoContext."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8ef7e8bdc083f..dbec40c42e3cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -498,6 +498,7 @@ class Analyzer(
       ResolveBinaryArithmetic ::
       new ResolveIdentifierClause(earlyBatches) ::
       ResolveUnion ::
+      ResolveZip ::
       FlattenSequentialStreamingUnion ::
       ValidateSequentialStreamingUnion ::
       ResolveRowLevelCommandAssignments ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 94c40215ab59f..c0715e0fa9ddf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -546,6 +546,21 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
             messageParameters = Map.empty)
         }
 
+      case z: Zip =>
+        // A Zip is only mergeable when both children reduce to the same base plan
+        // after stripping all Project layers; otherwise report a dedicated error.
+        def stripProjects(plan: LogicalPlan): LogicalPlan = plan match {
+          case Project(_, child) => stripProjects(child)
+          case other => other
+        }
+        val leftBase = stripProjects(z.left)
+        val rightBase = stripProjects(z.right)
+        if (!leftBase.sameResult(rightBase)) {
+          z.failAnalysis(
+            errorClass = "ZIP_PLANS_NOT_MERGEABLE",
+            messageParameters = Map.empty)
+        }
+
       case a: Aggregate =>
         a.groupingExpressions.foreach(
           expression =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala
new file mode 100644
index 0000000000000..c770c7df83f12
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, NamedExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Zip}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.ZIP
+
+/**
+ * Resolves a [[Zip]] node by rewriting it into a single [[Project]] over the shared base plan.
+ *
+ * Both children of Zip must derive from the same base plan through chains of Project nodes.
+ * This rule:
+ *   1. Waits for both children to be resolved
+ *   2. Strips Project layers from each side to find the base plan
+ *   3. Verifies the base plans produce the same result (via `sameResult`)
+ *   4. Remaps the right side's attribute references to the left base plan's output
+ *   5. Produces a single Project that combines both sides' expressions
+ *
+ * If the base plans do not match, the Zip node remains unresolved and CheckAnalysis
+ * will report a `ZIP_PLANS_NOT_MERGEABLE` error.
+ */
+object ResolveZip extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(ZIP), ruleId) {
+    case z: Zip if z.childrenResolved =>
+      val (leftExprs, leftBase) = extractProjectAndBase(z.left)
+      val (rightExprs, rightBase) = extractProjectAndBase(z.right)
+      if (leftBase.sameResult(rightBase)) {
+        // Build an attribute mapping from rightBase output to leftBase output (by position)
+        val attrMapping = AttributeMap(rightBase.output.zip(leftBase.output))
+        // Remap right expressions to reference leftBase's attributes
+        val remappedRightExprs = rightExprs.map { expr =>
+          expr.transform {
+            case a: Attribute => attrMapping.getOrElse(a, a)
+          }.asInstanceOf[NamedExpression]
+        }
+        Project(leftExprs ++ remappedRightExprs, leftBase)
+      } else {
+        z
+      }
+  }
+
+  private def extractProjectAndBase(
+      plan: LogicalPlan): (Seq[NamedExpression], LogicalPlan) = plan match {
+    case Project(projectList, child) => (projectList, child)
+    case other => (other.output, other)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c18b7fcecc484..e151435b03961 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -829,6 +829,29 @@ case class Join(
     newLeft: LogicalPlan,
     newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight)
 }
+
+/**
+ * A logical plan that combines the columns of two DataFrames that derive from the same
+ * base plan through chains of Project nodes. This node is always unresolved and must be
+ * rewritten by [[ResolveZip]] into a single Project over the shared base plan during
+ * analysis. If the two children do not share the same base plan (after stripping Project
+ * nodes), analysis will fail with an error.
+ */
+case class Zip(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def maxRows: Option[Long] = left.maxRows
+
+  override def maxRowsPerPartition: Option[Long] = left.maxRowsPerPartition
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ZIP)
+
+  // Always unresolved -- must be rewritten by ResolveZip during analysis.
+  override lazy val resolved: Boolean = false
+
+  override protected def withNewChildrenInternal(
+    newLeft: LogicalPlan, newRight: LogicalPlan): Zip = copy(left = newLeft, right = newRight)
+}
 
 /**
  * Insert query result into a directory.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 1e718c02f5ea5..1dfb45af593e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -104,6 +104,7 @@ object RuleIdCollection {
     "org.apache.spark.sql.catalyst.analysis.ResolveTableSpec" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
+    "org.apache.spark.sql.catalyst.analysis.ResolveZip" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" ::
     "org.apache.spark.sql.catalyst.analysis.ResolveWindowTime" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 1e22c1ce86539..b3d96da1cb52a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -176,6 +176,7 @@ object TreePattern extends Enumeration {
   val TRANSPOSE: Value = Value
   val UNION: Value = Value
   val UNPIVOT: Value = Value
+  val ZIP: Value = Value
   val UPDATE_EVENT_TIME_WATERMARK_COLUMN: Value = Value
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala
new file mode 100644
index 0000000000000..5e1235cfec231
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ResolveZipSuite extends AnalysisTest {
+
+  private val base = LocalRelation($"a".int, $"b".int, $"c".int)
+
+  object Resolve extends RuleExecutor[LogicalPlan] {
+    override val batches: Seq[Batch] = Seq(
+      Batch("ResolveZip", Once, ResolveZip))
+  }
+
+  test("resolve Zip: both sides have Project over same base") {
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base.output(1)), base)
+    val zip = Zip(left, right)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(Seq(base.output(0), base.output(1)), base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: left is bare plan, right has Project") {
+    val right = Project(Seq(base.output(0)), base)
+    val zip = Zip(base, right)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(base.output ++ Seq(base.output(0)), base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: both sides are bare same plan") {
+    val zip = Zip(base, base)
+
+    val resolved = Resolve.execute(zip)
+    val expected = Project(base.output ++ base.output, base)
+    comparePlans(resolved, expected)
+  }
+
+  test("resolve Zip: both sides have expressions over same base") {
+    val left = base.select(($"a" + 1).as("a_plus_1"))
+    val right = base.select(($"b" * 2).as("b_times_2"))
+    val zip = Zip(left.analyze, right.analyze)
+
+    val resolved = Resolve.execute(zip)
+    assert(!resolved.isInstanceOf[Zip], "Zip should have been resolved to a Project")
+    assert(resolved.isInstanceOf[Project])
+    assert(resolved.output.length == 2)
+    assert(resolved.output(0).name == "a_plus_1")
+    assert(resolved.output(1).name == "b_times_2")
+  }
+
+  test("resolve Zip: different base plans - Zip remains unresolved") {
+    val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base2.output(0)), base2)
+    val zip = Zip(left, right)
+
+    val resolved = Resolve.execute(zip)
+    // ResolveZip cannot merge, so Zip stays
+    assert(resolved.isInstanceOf[Zip])
+  }
+
+  test("resolve Zip: skipped when children are unresolved") {
+    val unresolvedChild = Project(
+      Seq(UnresolvedAttribute("a")),
+      UnresolvedRelation(Seq("t")))
+    val zip = Zip(unresolvedChild, unresolvedChild)
+
+    val result = Resolve.execute(zip)
+    // Zip should remain unchanged because children are not resolved
+    assert(result.isInstanceOf[Zip])
+  }
+
+  test("CheckAnalysis: different base plans throws ZIP_PLANS_NOT_MERGEABLE") {
+    val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
+    val left = Project(Seq(base.output(0)), base)
+    val right = Project(Seq(base2.output(0)), base2)
+    val zip = Zip(left, right)
+
+    assertAnalysisErrorCondition(
+      zip,
+      expectedErrorCondition = "ZIP_PLANS_NOT_MERGEABLE",
+      expectedMessageParameters = Map.empty
+    )
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 84b356855710a..33780931afaf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -707,6 +707,19 @@ class Dataset[T] private[sql](
     Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
   }
 
+  /**
+   * Combines the columns of this DataFrame with another DataFrame that derives from the same
+   * base plan through Project operations. The analyzer rewrites the resulting Zip node into a
+   * single Project over the shared base plan.
+   *
+   * @param other another DataFrame that shares the same base plan
+   * @return a new DataFrame with columns from both sides
+   * @throws AnalysisException if the two DataFrames do not derive from the same base plan
+   */
+  def zip(other: sql.Dataset[_]): DataFrame = withPlan {
+    Zip(logicalPlan, other.logicalPlan)
+  }
+
   /** @inheritdoc */
   def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
     // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala
new file mode 100644
index 0000000000000..bf5b2fdcf1eb0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.test.SharedSparkSession
+
+class DataFrameZipSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("zip: select different columns from the same DataFrame") {
+    val df = Seq((1, 2, 3), (4, 5, 6), (7, 8, 9)).toDF("a", "b", "c")
+    val left = df.select("a")
+    val right = df.select("b")
+
+    checkAnswer(
+      left.zip(right),
+      Row(1, 2) :: Row(4, 5) :: Row(7, 8) :: Nil)
+  }
+
+  test("zip: select with expressions over the same DataFrame") {
+    val df = Seq((1, 10), (2, 20), (3, 30)).toDF("a", "b")
+    val left = df.select(($"a" + 1).as("a_plus_1"))
+    val right = df.select(($"b" * 2).as("b_times_2"))
+
+    checkAnswer(
+      left.zip(right),
+      Row(2, 20) :: Row(3, 40) :: Row(4, 60) :: Nil)
+  }
+
+  test("zip: one side selects all columns") {
+    val df = Seq((1, 2), (3, 4)).toDF("a", "b")
+    val right = df.select(($"a" + $"b").as("sum"))
+
+    checkAnswer(
+      df.zip(right),
+      Row(1, 2, 3) :: Row(3, 4, 7) :: Nil)
+  }
+
+  test("zip: resolved plan is a Project") {
+    val df = Seq((1, 2)).toDF("a", "b")
+    val left = df.select("a")
+    val right = df.select("b")
+    val zipped = left.zip(right)
+
+    assert(zipped.queryExecution.analyzed.isInstanceOf[Project])
+  }
+
+  test("zip: different base plans throws AnalysisException") {
+    val df1 = Seq((1, 2)).toDF("a", "b")
+    val df2 = Seq((3, 4, 5)).toDF("x", "y", "z")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df1.select("a").zip(df2.select("x")).queryExecution.assertAnalyzed()
+      },
+      condition = "ZIP_PLANS_NOT_MERGEABLE"
+    )
+  }
+
+  test("zip: different base plans from spark.range throws AnalysisException") {
+    val df1 = spark.range(10).toDF("id1")
+    val df2 = spark.range(20).toDF("id2")
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        df1.zip(df2).queryExecution.assertAnalyzed()
+      },
+      condition = "ZIP_PLANS_NOT_MERGEABLE"
+    )
+  }
+}