@@ -21,18 +21,14 @@ import scala.collection.immutable.SortedMap
2121
2222import org .apache .spark .OneToOneDependency
2323import org .apache .spark .Partition
24+ import org .apache .spark .internal .Logging
2425import org .apache .spark .sql .auron .NativeConverters
2526import org .apache .spark .sql .auron .NativeHelper
2627import org .apache .spark .sql .auron .NativeRDD
2728import org .apache .spark .sql .auron .NativeSupports
2829import org .apache .spark .sql .auron .Shims
2930import org .apache .spark .sql .catalyst .expressions .Expression
30- import org .apache .spark .sql .catalyst .plans .FullOuter
31- import org .apache .spark .sql .catalyst .plans .JoinType
32- import org .apache .spark .sql .catalyst .plans .LeftAnti
33- import org .apache .spark .sql .catalyst .plans .LeftOuter
34- import org .apache .spark .sql .catalyst .plans .LeftSemi
35- import org .apache .spark .sql .catalyst .plans .RightOuter
31+ import org .apache .spark .sql .catalyst .plans .{ExistenceJoin , FullOuter , JoinType , LeftAnti , LeftOuter , LeftSemi , RightOuter }
3632import org .apache .spark .sql .catalyst .plans .physical .Partitioning
3733import org .apache .spark .sql .execution .BinaryExecNode
3834import org .apache .spark .sql .execution .SparkPlan
@@ -43,7 +39,7 @@ import org.apache.spark.sql.types.LongType
4339
4440import org .apache .auron .{protobuf => pb }
4541import org .apache .auron .metric .SparkMetricNode
46- import org .apache .auron .protobuf .JoinOn
42+ import org .apache .auron .protobuf .{ EmptyPartitionsExecNode , JoinOn , PhysicalPlanNode }
4743
4844abstract class NativeBroadcastJoinBase (
4945 override val left : SparkPlan ,
@@ -55,7 +51,8 @@ abstract class NativeBroadcastJoinBase(
5551 broadcastSide : BroadcastSide ,
5652 isNullAwareAntiJoin : Boolean )
5753 extends BinaryExecNode
58- with NativeSupports {
54+ with NativeSupports
55+ with Logging {
5956
6057 override lazy val metrics : Map [String , SQLMetric ] = SortedMap [String , SQLMetric ]() ++ Map (
6158 NativeHelper
@@ -127,44 +124,99 @@ abstract class NativeBroadcastJoinBase(
127124 override def doExecuteNative (): NativeRDD = {
128125 val leftRDD = NativeHelper .executeNative(left)
129126 val rightRDD = NativeHelper .executeNative(right)
130- val nativeMetrics = SparkMetricNode (metrics, leftRDD.metrics :: rightRDD.metrics :: Nil )
131- val nativeSchema = this .nativeSchema
132- val nativeJoinType = this .nativeJoinType
133- val nativeJoinOn = this .nativeJoinOn
134127
135128 val (probedRDD, builtRDD) = broadcastSide match {
136129 case BroadcastLeft => (rightRDD, leftRDD)
137130 case BroadcastRight => (leftRDD, rightRDD)
138131 }
139132
133+ // Handle the edge case when probed side is empty (no partitions)
134+ // This matches Spark's BroadcastNestedLoopJoinExec behavior for condition.isEmpty case:
135+ // val streamExists = !streamed.executeTake(1).isEmpty
136+ // if (streamExists == exists) sparkContext.makeRDD(relation.value)
137+ // else sparkContext.emptyRDD
138+ // where exists = true for Semi, false for Anti
139+ //
140+ // Note: This optimization only applies to Semi/Anti joins.
141+ if (probedRDD.partitions.isEmpty) {
142+ joinType match {
143+ case LeftAnti =>
144+ return builtRDD
145+ case LeftSemi =>
146+ return probedRDD
147+ case _ =>
148+ }
149+ }
150+
151+ val nativeMetrics = SparkMetricNode (metrics, leftRDD.metrics :: rightRDD.metrics :: Nil )
152+ val nativeSchema = this .nativeSchema
153+ val nativeJoinType = this .nativeJoinType
154+ val nativeJoinOn = this .nativeJoinOn
155+
140156 val probedShuffleReadFull = probedRDD.isShuffleReadFull && (broadcastSide match {
141157 case BroadcastLeft =>
142158 Seq (FullOuter , RightOuter ).contains(joinType)
143159 case BroadcastRight =>
144160 Seq (FullOuter , LeftOuter , LeftSemi , LeftAnti ).contains(joinType)
145161 })
146162
163+ // For ExistenceJoin with empty probed side, use builtRDD.partitions to ensure
164+ // native join can execute and finish() will output all build rows with exists=false
165+ val (rddPartitions, rddPartitioner, rddDependencies) =
166+ if (probedRDD.partitions.isEmpty && joinType.isInstanceOf [ExistenceJoin ]) {
167+ (builtRDD.partitions, builtRDD.partitioner, new OneToOneDependency (builtRDD) :: Nil )
168+ } else {
169+ (probedRDD.partitions, probedRDD.partitioner, new OneToOneDependency (probedRDD) :: Nil )
170+ }
171+
147172 new NativeRDD (
148173 sparkContext,
149174 nativeMetrics,
150- probedRDD.partitions ,
151- rddPartitioner = probedRDD.partitioner ,
152- rddDependencies = new OneToOneDependency (probedRDD) :: Nil ,
175+ rddPartitions ,
176+ rddPartitioner = rddPartitioner ,
177+ rddDependencies = rddDependencies ,
153178 probedShuffleReadFull,
154179 (partition, context) => {
155180 val partition0 = new Partition () {
156181 override def index : Int = 0
157182 }
158- val (leftChild, rightChild) = broadcastSide match {
159- case BroadcastLeft =>
160- (
161- leftRDD.nativePlan(partition0, context),
162- rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
163- case BroadcastRight =>
164- (
165- leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
166- rightRDD.nativePlan(partition0, context))
167- }
183+ val (leftChild, rightChild) =
184+ if (probedRDD.partitions.isEmpty && joinType.isInstanceOf [ExistenceJoin ]) {
185+ val probedSchema = broadcastSide match {
186+ case BroadcastLeft => Util .getNativeSchema(right.output)
187+ case BroadcastRight => Util .getNativeSchema(left.output)
188+ }
189+ val emptyProbedPlan = PhysicalPlanNode
190+ .newBuilder()
191+ .setEmptyPartitions(
192+ EmptyPartitionsExecNode
193+ .newBuilder()
194+ .setNumPartitions(1 )
195+ .setSchema(probedSchema)
196+ .build())
197+ .build()
198+ broadcastSide match {
199+ case BroadcastLeft =>
200+ (
201+ leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
202+ emptyProbedPlan)
203+ case BroadcastRight =>
204+ (
205+ emptyProbedPlan,
206+ rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
207+ }
208+ } else {
209+ broadcastSide match {
210+ case BroadcastLeft =>
211+ (
212+ leftRDD.nativePlan(partition0, context),
213+ rightRDD.nativePlan(rightRDD.partitions(partition.index), context))
214+ case BroadcastRight =>
215+ (
216+ leftRDD.nativePlan(leftRDD.partitions(partition.index), context),
217+ rightRDD.nativePlan(partition0, context))
218+ }
219+ }
168220 val cachedBuildHashMapId = s " bhm_stage ${context.stageId}_rdd ${builtRDD.id}"
169221
170222 val broadcastJoinExec = pb.BroadcastJoinExecNode
0 commit comments