Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -8331,6 +8331,12 @@
],
"sqlState" : "42KDF"
},
"ZIP_PLANS_NOT_MERGEABLE" : {
"message" : [
"The two DataFrames in zip() cannot be merged because they do not derive from the same base plan through Project operations."
],
"sqlState" : "42K03"
},
"_LEGACY_ERROR_TEMP_0001" : {
"message" : [
"Invalid InsertIntoContext."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ class Analyzer(
ResolveBinaryArithmetic ::
new ResolveIdentifierClause(earlyBatches) ::
ResolveUnion ::
ResolveZip ::
FlattenSequentialStreamingUnion ::
ValidateSequentialStreamingUnion ::
ResolveRowLevelCommandAssignments ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,19 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
messageParameters = Map.empty)
}

// A Zip node is always unresolved and must be rewritten away by ResolveZip; reaching
// this check means the rule declined to merge the two children. Re-derive the base
// plans here so we can report a dedicated error when they genuinely differ.
case z: Zip =>
// Peel off every Project layer to expose the underlying base plan on each side.
def stripProjects(plan: LogicalPlan): LogicalPlan = plan match {
case Project(_, child) => stripProjects(child)
case other => other
}
val leftBase = stripProjects(z.left)
val rightBase = stripProjects(z.right)
// sameResult compares canonicalized plans, ignoring cosmetic differences such as exprIds.
if (!leftBase.sameResult(rightBase)) {
z.failAnalysis(
errorClass = "ZIP_PLANS_NOT_MERGEABLE",
messageParameters = Map.empty)
}

case a: Aggregate =>
a.groupingExpressions.foreach(
expression =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Zip}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.ZIP

/**
* Resolves a [[Zip]] node by rewriting it into a single [[Project]] over the shared base plan.
*
* Both children of Zip must derive from the same base plan through chains of Project nodes.
* This rule:
* 1. Waits for both children to be resolved
* 2. Strips Project layers from each side to find the base plan
* 3. Verifies the base plans produce the same result (via `sameResult`)
* 4. Remaps the right side's attribute references to the left base plan's output
* 5. Produces a single Project that combines both sides' expressions
*
 * If the base plans do not match, the Zip node remains unresolved and CheckAnalysis
 * will report a `ZIP_PLANS_NOT_MERGEABLE` error.
*/
object ResolveZip extends Rule[LogicalPlan] with AliasHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
    _.containsPattern(ZIP), ruleId) {
    case z: Zip if z.childrenResolved =>
      val (leftExprs, leftBase) = extractProjectAndBase(z.left)
      val (rightExprs, rightBase) = extractProjectAndBase(z.right)
      if (leftBase.sameResult(rightBase)) {
        // sameResult guarantees the two bases are canonically identical, so their output
        // attributes correspond positionally; remap right-side references onto leftBase.
        val attrMapping = AttributeMap(rightBase.output.zip(leftBase.output))
        // Remap right expressions to reference leftBase's attributes.
        val remappedRightExprs = rightExprs.map { expr =>
          expr.transform {
            case a: Attribute => attrMapping.getOrElse(a, a)
          }.asInstanceOf[NamedExpression]
        }
        Project(leftExprs ++ remappedRightExprs, leftBase)
      } else {
        // Leave the Zip unresolved; CheckAnalysis will raise ZIP_PLANS_NOT_MERGEABLE.
        z
      }
  }

  /**
   * Strips a chain of [[Project]] nodes off `plan`, returning an equivalent single project
   * list over the innermost non-Project base plan.
   *
   * Unlike a single-layer extraction, nested Projects are collapsed by inlining the child's
   * aliases into the outer project list (keeping the outer names via
   * `replaceAliasButKeepName`), so the returned expressions reference the base plan's
   * attributes directly. This matches both the class doc ("chains of Project nodes") and
   * the recursive stripping performed by CheckAnalysis. Collapsing stops at a layer whose
   * project list contains a non-deterministic expression, since inlining could duplicate
   * that expression and change its semantics.
   */
  private def extractProjectAndBase(
      plan: LogicalPlan): (Seq[NamedExpression], LogicalPlan) = plan match {
    case Project(projectList, child) =>
      val (childExprs, base) = extractProjectAndBase(child)
      if (child eq base) {
        // Single Project directly over a non-Project base: nothing to collapse.
        (projectList, base)
      } else if (childExprs.exists(!_.deterministic)) {
        // Unsafe to inline non-deterministic expressions; treat the child as the base.
        (projectList, child)
      } else {
        // Inline the child's aliases so outer expressions reference base attributes.
        val aliasMap = AttributeMap(childExprs.collect { case a: Alias => a.toAttribute -> a })
        (projectList.map(replaceAliasButKeepName(_, aliasMap)), base)
      }
    case other => (other.output, other)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,29 @@ case class Join(
newLeft: LogicalPlan, newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight)
}

/**
 * Pairs up the columns of two logical plans that derive from one common base plan via
 * chains of Project operators. Its output is the left child's columns followed by the
 * right child's columns.
 *
 * This node never resolves on its own: the `ResolveZip` analysis rule must rewrite it
 * into a single Project over the shared base plan. If no common base plan exists after
 * peeling off Project nodes, analysis fails with an error.
 */
case class Zip(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
  // Kept unconditionally unresolved so that ResolveZip (or CheckAnalysis) must handle it.
  override lazy val resolved: Boolean = false

  final override val nodePatterns: Seq[TreePattern] = Seq(ZIP)

  override def output: Seq[Attribute] = left.output ++ right.output

  // Both children stem from the same base plan, so row-count bounds mirror the left side.
  override def maxRows: Option[Long] = left.maxRows

  override def maxRowsPerPartition: Option[Long] = left.maxRowsPerPartition

  override protected def withNewChildrenInternal(
      newLeft: LogicalPlan, newRight: LogicalPlan): Zip = copy(left = newLeft, right = newRight)
}

/**
* Insert query result into a directory.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.ResolveTableSpec" ::
"org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" ::
"org.apache.spark.sql.catalyst.analysis.ResolveUnion" ::
"org.apache.spark.sql.catalyst.analysis.ResolveZip" ::
"org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" ::
"org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" ::
"org.apache.spark.sql.catalyst.analysis.ResolveWindowTime" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ object TreePattern extends Enumeration {
val TRANSPOSE: Value = Value
val UNION: Value = Value
val UNPIVOT: Value = Value
val ZIP: Value = Value
val UPDATE_EVENT_TIME_WATERMARK_COLUMN: Value = Value
val TYPED_FILTER: Value = Value
val WINDOW: Value = Value
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ResolveZipSuite extends AnalysisTest {

  // Shared three-column relation that both sides of each Zip derive from.
  private val baseRelation = LocalRelation($"a".int, $"b".int, $"c".int)

  // Minimal executor that applies only the rule under test.
  object ZipResolver extends RuleExecutor[LogicalPlan] {
    override val batches: Seq[Batch] = Seq(
      Batch("ResolveZip", Once, ResolveZip))
  }

  test("resolve Zip: both sides have Project over same base") {
    val colA = baseRelation.output.head
    val colB = baseRelation.output(1)
    val zipped = Zip(
      Project(colA :: Nil, baseRelation),
      Project(colB :: Nil, baseRelation))

    comparePlans(
      ZipResolver.execute(zipped),
      Project(colA :: colB :: Nil, baseRelation))
  }

  test("resolve Zip: left is bare plan, right has Project") {
    val zipped = Zip(
      baseRelation,
      Project(baseRelation.output.head :: Nil, baseRelation))

    comparePlans(
      ZipResolver.execute(zipped),
      Project(baseRelation.output :+ baseRelation.output.head, baseRelation))
  }

  test("resolve Zip: both sides are bare same plan") {
    comparePlans(
      ZipResolver.execute(Zip(baseRelation, baseRelation)),
      Project(baseRelation.output ++ baseRelation.output, baseRelation))
  }

  test("resolve Zip: both sides have expressions over same base") {
    val leftSide = baseRelation.select(($"a" + 1).as("a_plus_1")).analyze
    val rightSide = baseRelation.select(($"b" * 2).as("b_times_2")).analyze

    val result = ZipResolver.execute(Zip(leftSide, rightSide))
    assert(!result.isInstanceOf[Zip], "Zip should have been resolved to a Project")
    assert(result.isInstanceOf[Project])
    assert(result.output.length == 2)
    assert(result.output.head.name == "a_plus_1")
    assert(result.output(1).name == "b_times_2")
  }

  test("resolve Zip: different base plans - Zip remains unresolved") {
    val otherRelation = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
    val zipped = Zip(
      Project(baseRelation.output.head :: Nil, baseRelation),
      Project(otherRelation.output.head :: Nil, otherRelation))

    // ResolveZip cannot merge, so Zip stays
    assert(ZipResolver.execute(zipped).isInstanceOf[Zip])
  }

  test("resolve Zip: skipped when children are unresolved") {
    val unresolvedSide = Project(
      UnresolvedAttribute("a") :: Nil,
      UnresolvedRelation(Seq("t")))

    // Zip should remain unchanged because children are not resolved
    assert(ZipResolver.execute(Zip(unresolvedSide, unresolvedSide)).isInstanceOf[Zip])
  }

  test("CheckAnalysis: different base plans throws ZIP_PLANS_NOT_MERGEABLE") {
    val otherRelation = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int)
    val zipped = Zip(
      Project(baseRelation.output.head :: Nil, baseRelation),
      Project(otherRelation.output.head :: Nil, otherRelation))

    assertAnalysisErrorCondition(
      zipped,
      expectedErrorCondition = "ZIP_PLANS_NOT_MERGEABLE",
      expectedMessageParameters = Map.empty
    )
  }
}
13 changes: 13 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,19 @@ class Dataset[T] private[sql](
Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE)
}

/**
 * Combines the columns of this DataFrame with another DataFrame that derives from the same
 * base plan through Project operations. The analyzer (via the `ResolveZip` rule) rewrites
 * the resulting Zip node into a single Project over the shared base plan during analysis.
 *
 * @param other another DataFrame that shares the same base plan
 * @return a new DataFrame with this DataFrame's columns followed by `other`'s columns
 * @throws AnalysisException if the two DataFrames do not derive from the same base plan
 */
def zip(other: sql.Dataset[_]): DataFrame = withPlan {
Zip(logicalPlan, other.logicalPlan)
}

/** @inheritdoc */
def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
// Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.test.SharedSparkSession

class DataFrameZipSuite extends QueryTest with SharedSparkSession {
  import testImplicits._

  test("zip: select different columns from the same DataFrame") {
    val source = Seq((1, 2, 3), (4, 5, 6), (7, 8, 9)).toDF("a", "b", "c")

    checkAnswer(
      source.select("a").zip(source.select("b")),
      Seq(Row(1, 2), Row(4, 5), Row(7, 8)))
  }

  test("zip: select with expressions over the same DataFrame") {
    val source = Seq((1, 10), (2, 20), (3, 30)).toDF("a", "b")
    val incremented = source.select(($"a" + 1).as("a_plus_1"))
    val doubled = source.select(($"b" * 2).as("b_times_2"))

    checkAnswer(
      incremented.zip(doubled),
      Seq(Row(2, 20), Row(3, 40), Row(4, 60)))
  }

  test("zip: one side selects all columns") {
    val source = Seq((1, 2), (3, 4)).toDF("a", "b")
    val summed = source.select(($"a" + $"b").as("sum"))

    checkAnswer(
      source.zip(summed),
      Seq(Row(1, 2, 3), Row(3, 4, 7)))
  }

  test("zip: resolved plan is a Project") {
    val source = Seq((1, 2)).toDF("a", "b")
    val zipped = source.select("a").zip(source.select("b"))

    // The Zip node must have been rewritten away during analysis.
    assert(zipped.queryExecution.analyzed.isInstanceOf[Project])
  }

  test("zip: different base plans throws AnalysisException") {
    val first = Seq((1, 2)).toDF("a", "b")
    val second = Seq((3, 4, 5)).toDF("x", "y", "z")

    checkError(
      exception = intercept[AnalysisException] {
        first.select("a").zip(second.select("x")).queryExecution.assertAnalyzed()
      },
      condition = "ZIP_PLANS_NOT_MERGEABLE"
    )
  }

  test("zip: different base plans from spark.range throws AnalysisException") {
    val first = spark.range(10).toDF("id1")
    val second = spark.range(20).toDF("id2")

    checkError(
      exception = intercept[AnalysisException] {
        first.zip(second).queryExecution.assertAnalyzed()
      },
      condition = "ZIP_PLANS_NOT_MERGEABLE"
    )
  }
}