Skip to content

Commit aa28963

Browse files
committed
fix
1 parent 5a7f3e5 commit aa28963

3 files changed

Lines changed: 101 additions & 25 deletions

File tree

auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.auron.utils
1818

1919
import org.apache.spark.sql._
20+
import org.apache.spark.sql.execution.joins.AuronExistenceJoinSuite
2021

2122
class AuronSparkTestSettings extends SparkTestSettings {
2223
{
@@ -42,6 +43,8 @@ class AuronSparkTestSettings extends SparkTestSettings {
4243

4344
enableSuite[AuronTypedImperativeAggregateSuite]
4445

46+
enableSuite[AuronExistenceJoinSuite]
47+
4548
// Will be implemented in the future.
4649
override def getSQLQueryTestSettings = new SQLQueryTestSettings {
4750
override def getResourceFilePath: String = ???
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.execution.joins
18+
19+
import org.apache.spark.sql.SparkTestsSharedSessionBase
20+
21+
/**
 * Re-runs Spark's built-in `ExistenceJoinSuite` with [[SparkTestsSharedSessionBase]] mixed in,
 * so the existence-join tests execute against the Auron-configured shared Spark session.
 * NOTE(review): assumes `SparkTestsSharedSessionBase` enables the Auron native backend for the
 * session — confirm against its definition; no overrides are needed here, hence the empty body.
 */
class AuronExistenceJoinSuite extends ExistenceJoinSuite with SparkTestsSharedSessionBase {}

spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeBroadcastJoinBase.scala

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,14 @@ import scala.collection.immutable.SortedMap
2121

2222
import org.apache.spark.OneToOneDependency
2323
import org.apache.spark.Partition
24+
import org.apache.spark.internal.Logging
2425
import org.apache.spark.sql.auron.NativeConverters
2526
import org.apache.spark.sql.auron.NativeHelper
2627
import org.apache.spark.sql.auron.NativeRDD
2728
import org.apache.spark.sql.auron.NativeSupports
2829
import org.apache.spark.sql.auron.Shims
2930
import org.apache.spark.sql.catalyst.expressions.Expression
30-
import org.apache.spark.sql.catalyst.plans.FullOuter
31-
import org.apache.spark.sql.catalyst.plans.JoinType
32-
import org.apache.spark.sql.catalyst.plans.LeftAnti
33-
import org.apache.spark.sql.catalyst.plans.LeftOuter
34-
import org.apache.spark.sql.catalyst.plans.LeftSemi
35-
import org.apache.spark.sql.catalyst.plans.RightOuter
31+
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
3632
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
3733
import org.apache.spark.sql.execution.BinaryExecNode
3834
import org.apache.spark.sql.execution.SparkPlan
@@ -43,7 +39,7 @@ import org.apache.spark.sql.types.LongType
4339

4440
import org.apache.auron.{protobuf => pb}
4541
import org.apache.auron.metric.SparkMetricNode
46-
import org.apache.auron.protobuf.JoinOn
42+
import org.apache.auron.protobuf.{EmptyPartitionsExecNode, JoinOn, PhysicalPlanNode}
4743

4844
abstract class NativeBroadcastJoinBase(
4945
override val left: SparkPlan,
@@ -55,7 +51,8 @@ abstract class NativeBroadcastJoinBase(
5551
broadcastSide: BroadcastSide,
5652
isNullAwareAntiJoin: Boolean)
5753
extends BinaryExecNode
58-
with NativeSupports {
54+
with NativeSupports
55+
with Logging {
5956

6057
override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map(
6158
NativeHelper
@@ -127,44 +124,99 @@ abstract class NativeBroadcastJoinBase(
127124
override def doExecuteNative(): NativeRDD = {
128125
val leftRDD = NativeHelper.executeNative(left)
129126
val rightRDD = NativeHelper.executeNative(right)
130-
val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
131-
val nativeSchema = this.nativeSchema
132-
val nativeJoinType = this.nativeJoinType
133-
val nativeJoinOn = this.nativeJoinOn
134127

135128
val (probedRDD, builtRDD) = broadcastSide match {
136129
case BroadcastLeft => (rightRDD, leftRDD)
137130
case BroadcastRight => (leftRDD, rightRDD)
138131
}
139132

133+
// Handle the edge case when probed side is empty (no partitions)
134+
// This matches Spark's BroadcastNestedLoopJoinExec behavior for condition.isEmpty case:
135+
// val streamExists = !streamed.executeTake(1).isEmpty
136+
// if (streamExists == exists) sparkContext.makeRDD(relation.value)
137+
// else sparkContext.emptyRDD
138+
// where exists = true for Semi, false for Anti
139+
//
140+
// Note: This optimization only applies to Semi/Anti joins.
141+
if (probedRDD.partitions.isEmpty) {
142+
joinType match {
143+
case LeftAnti =>
144+
return builtRDD
145+
case LeftSemi =>
146+
return probedRDD
147+
case _ =>
148+
}
149+
}
150+
151+
val nativeMetrics = SparkMetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil)
152+
val nativeSchema = this.nativeSchema
153+
val nativeJoinType = this.nativeJoinType
154+
val nativeJoinOn = this.nativeJoinOn
155+
140156
val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide match {
141157
case BroadcastLeft =>
142158
Seq(FullOuter, RightOuter).contains(joinType)
143159
case BroadcastRight =>
144160
Seq(FullOuter, LeftOuter, LeftSemi, LeftAnti).contains(joinType)
145161
})
146162

163+
// For ExistenceJoin with empty probed side, use builtRDD.partitions to ensure
164+
// native join can execute and finish() will output all build rows with exists=false
165+
val (rddPartitions, rddPartitioner, rddDependencies) =
166+
if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
167+
(builtRDD.partitions, builtRDD.partitioner, new OneToOneDependency(builtRDD) :: Nil)
168+
} else {
169+
(probedRDD.partitions, probedRDD.partitioner, new OneToOneDependency(probedRDD) :: Nil)
170+
}
171+
147172
new NativeRDD(
148173
sparkContext,
149174
nativeMetrics,
150-
probedRDD.partitions,
151-
rddPartitioner = probedRDD.partitioner,
152-
rddDependencies = new OneToOneDependency(probedRDD) :: Nil,
175+
rddPartitions,
176+
rddPartitioner = rddPartitioner,
177+
rddDependencies = rddDependencies,
153178
probedShuffleReadFull,
154179
(partition, context) => {
155180
val partition0 = new Partition() {
156181
override def index: Int = 0
157182
}
158-
val (leftChild, rightChild) = broadcastSide match {
159-
case BroadcastLeft =>
160-
(
161-
leftRDD.nativePlan(partition0, context),
162-
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
163-
case BroadcastRight =>
164-
(
165-
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
166-
rightRDD.nativePlan(partition0, context))
167-
}
183+
val (leftChild, rightChild) =
184+
if (probedRDD.partitions.isEmpty && joinType.isInstanceOf[ExistenceJoin]) {
185+
val probedSchema = broadcastSide match {
186+
case BroadcastLeft => Util.getNativeSchema(right.output)
187+
case BroadcastRight => Util.getNativeSchema(left.output)
188+
}
189+
val emptyProbedPlan = PhysicalPlanNode
190+
.newBuilder()
191+
.setEmptyPartitions(
192+
EmptyPartitionsExecNode
193+
.newBuilder()
194+
.setNumPartitions(1)
195+
.setSchema(probedSchema)
196+
.build())
197+
.build()
198+
broadcastSide match {
199+
case BroadcastLeft =>
200+
(
201+
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
202+
emptyProbedPlan)
203+
case BroadcastRight =>
204+
(
205+
emptyProbedPlan,
206+
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
207+
}
208+
} else {
209+
broadcastSide match {
210+
case BroadcastLeft =>
211+
(
212+
leftRDD.nativePlan(partition0, context),
213+
rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
214+
case BroadcastRight =>
215+
(
216+
leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
217+
rightRDD.nativePlan(partition0, context))
218+
}
219+
}
168220
val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}"
169221

170222
val broadcastJoinExec = pb.BroadcastJoinExecNode

0 commit comments

Comments
 (0)