Skip to content

Commit 8a69b5a

Browse files
Dandandan and claude committed
Fix channel-based multi-output routing and output partitioning
- Fix ProducingOutput to carry the partition index alongside the batch, ensuring correct routing when emit_next_partition eagerly advances the index
- Add a use_channels() method to centralize the decision of when to use the channel-based multi-output path
- Add update_cache_partitioning() to keep output partitioning in sync when limit_options or num_agg_partitions change
- Fix with_new_limit_options to recalculate output partitioning (prevents TopK aggregation from claiming Hash partitioning when channels won't be used)
- Guard CombinePartialFinalAggregate from combining when the Partial aggregate has num_agg_partitions > 1
- Set num_agg_partitions in the physical planner when repartitioning is enabled
- Keep spawned task references alive via Arc in output streams to prevent abort-on-drop

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 27fb7c8 commit 8a69b5a

4 files changed

Lines changed: 152 additions & 156 deletions

File tree

datafusion/core/src/physical_planner.rs

Lines changed: 29 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1043,19 +1043,39 @@ impl DefaultPhysicalPlanner {
10431043
input_exec
10441044
};
10451045

1046-
let initial_aggr = Arc::new(AggregateExec::try_new(
1047-
AggregateMode::Partial,
1048-
groups.clone(),
1049-
aggregates,
1050-
filters.clone(),
1051-
input_exec,
1052-
Arc::clone(&physical_input_schema),
1053-
)?);
1054-
10551046
let can_repartition = !groups.is_empty()
10561047
&& session_state.config().target_partitions() > 1
10571048
&& session_state.config().repartition_aggregations();
10581049

1050+
let initial_aggr = if can_repartition {
1051+
// When repartitioning is enabled, set num_agg_partitions
1052+
// so the partial aggregate produces target_partitions
1053+
// output partitions directly (eliminating the need for
1054+
// a downstream RepartitionExec).
1055+
Arc::new(
1056+
AggregateExec::try_new(
1057+
AggregateMode::Partial,
1058+
groups.clone(),
1059+
aggregates,
1060+
filters.clone(),
1061+
input_exec,
1062+
Arc::clone(&physical_input_schema),
1063+
)?
1064+
.with_num_agg_partitions(
1065+
session_state.config().target_partitions(),
1066+
),
1067+
)
1068+
} else {
1069+
Arc::new(AggregateExec::try_new(
1070+
AggregateMode::Partial,
1071+
groups.clone(),
1072+
aggregates,
1073+
filters.clone(),
1074+
input_exec,
1075+
Arc::clone(&physical_input_schema),
1076+
)?)
1077+
};
1078+
10591079
// Some aggregators may be modified during initialization for
10601080
// optimization purposes. For example, a FIRST_VALUE may turn
10611081
// into a LAST_VALUE with the reverse ordering requirement.

datafusion/physical-optimizer/src/combine_partial_final_agg.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
7373
};
7474

7575
let transformed = if *input_agg_exec.mode() == AggregateMode::Partial
76+
&& input_agg_exec.num_agg_partitions() <= 1
7677
&& can_combine(
7778
(
7879
agg_exec.group_expr(),

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 69 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -634,8 +634,7 @@ enum AggRepartitionState {
634634
/// Input tasks have been spawned and are sending to output channels.
635635
Consuming {
636636
/// One receiver per output partition. `None` if already taken by `execute()`.
637-
receivers:
638-
Vec<Option<tokio::sync::mpsc::UnboundedReceiver<Result<RecordBatch>>>>,
637+
receivers: Vec<Option<tokio::sync::mpsc::UnboundedReceiver<Result<RecordBatch>>>>,
639638
/// Background tasks; dropped to abort if the exec is dropped.
640639
_abort_helper: Vec<SpawnedTask<()>>,
641640
},
@@ -745,7 +744,7 @@ impl AggregateExec {
745744

746745
/// Clone this exec, overriding only the limit hint.
747746
pub fn with_new_limit_options(&self, limit_options: Option<LimitOptions>) -> Self {
748-
Self {
747+
let mut new = Self {
749748
limit_options,
750749
// clone the rest of the fields
751750
required_input_ordering: self.required_input_ordering.clone(),
@@ -762,7 +761,9 @@ impl AggregateExec {
762761
dynamic_filter: self.dynamic_filter.clone(),
763762
num_agg_partitions: self.num_agg_partitions,
764763
repartition_state: Arc::new(Mutex::new(AggRepartitionState::default())),
765-
}
764+
};
765+
new.update_cache_partitioning();
766+
new
766767
}
767768

768769
pub fn cache(&self) -> &PlanProperties {
@@ -910,27 +911,55 @@ impl AggregateExec {
910911
&self.mode
911912
}
912913

914+
/// Number of internal hash table partitions.
915+
pub fn num_agg_partitions(&self) -> usize {
916+
self.num_agg_partitions
917+
}
918+
913919
/// Set the number of internal hash table partitions for partial aggregation.
914920
/// When n > 1, the aggregate also acts as a repartitioning operator,
915921
/// producing n output partitions.
916922
pub fn with_num_agg_partitions(mut self, n: usize) -> Self {
917923
self.num_agg_partitions = n;
918-
if n > 1 {
919-
// Update output partitioning to reflect the repartitioned output.
920-
let group_exprs = self.output_group_expr();
921-
let mut cache = (*self.cache).clone();
922-
cache.partitioning = Partitioning::Hash(group_exprs, n);
923-
self.cache = Arc::new(cache);
924-
}
924+
self.update_cache_partitioning();
925925
self
926926
}
927927

928928
/// Set the limit options for this AggExec
929929
pub fn with_limit_options(mut self, limit_options: Option<LimitOptions>) -> Self {
930930
self.limit_options = limit_options;
931+
self.update_cache_partitioning();
931932
self
932933
}
933934

935+
/// Returns true if the channel-based multi-output path will be used
936+
/// in execute().
937+
fn use_channels(&self) -> bool {
938+
self.num_agg_partitions > 1
939+
&& !self.group_by.is_true_no_grouping()
940+
&& (self.limit_options.is_none()
941+
|| self.is_unordered_unfiltered_group_by_distinct())
942+
}
943+
944+
/// Update the cached output partitioning based on whether channels
945+
/// will be used.
946+
fn update_cache_partitioning(&mut self) {
947+
if self.use_channels() {
948+
let group_exprs = self.output_group_expr();
949+
let mut cache = (*self.cache).clone();
950+
cache.partitioning = Partitioning::Hash(group_exprs, self.num_agg_partitions);
951+
self.cache = Arc::new(cache);
952+
} else if self.num_agg_partitions > 1 {
953+
// Channels won't be used (e.g. limit/TopK was set), revert
954+
// to single-partition output matching the non-channel path.
955+
let mut cache = (*self.cache).clone();
956+
cache.partitioning = Partitioning::UnknownPartitioning(
957+
self.input.output_partitioning().partition_count(),
958+
);
959+
self.cache = Arc::new(cache);
960+
}
961+
}
962+
934963
/// Get the limit options (if set)
935964
pub fn limit_options(&self) -> Option<LimitOptions> {
936965
self.limit_options
@@ -1028,8 +1057,7 @@ impl AggregateExec {
10281057
Err(e) => {
10291058
let e = Arc::new(e);
10301059
for sender in &senders {
1031-
let _ =
1032-
sender.send(Err(DataFusionError::from(&e)));
1060+
let _ = sender.send(Err(DataFusionError::from(&e)));
10331061
}
10341062
return;
10351063
}
@@ -1553,8 +1581,7 @@ impl ExecutionPlan for AggregateExec {
15531581
if self.num_agg_partitions > 1 {
15541582
let group_exprs = me.output_group_expr();
15551583
let mut cache = (*me.cache).clone();
1556-
cache.partitioning =
1557-
Partitioning::Hash(group_exprs, self.num_agg_partitions);
1584+
cache.partitioning = Partitioning::Hash(group_exprs, self.num_agg_partitions);
15581585
me.cache = Arc::new(cache);
15591586
}
15601587

@@ -1566,7 +1593,7 @@ impl ExecutionPlan for AggregateExec {
15661593
partition: usize,
15671594
context: Arc<TaskContext>,
15681595
) -> Result<SendableRecordBatchStream> {
1569-
if self.num_agg_partitions <= 1 {
1596+
if !self.use_channels() {
15701597
return self
15711598
.execute_typed(partition, &context)
15721599
.map(|stream| stream.into());
@@ -1593,18 +1620,12 @@ impl ExecutionPlan for AggregateExec {
15931620
// Spawn one task per input partition
15941621
let mut tasks = Vec::with_capacity(num_input);
15951622
for input_idx in 0..num_input {
1596-
let stream = GroupedHashAggregateStream::new(
1597-
self, &context, input_idx,
1598-
)?;
1599-
let senders_clone: Vec<_> =
1600-
senders.iter().map(|s| s.clone()).collect();
1623+
let stream = GroupedHashAggregateStream::new(self, &context, input_idx)?;
1624+
let senders_clone: Vec<_> = senders.iter().map(|s| s.clone()).collect();
16011625
let n = num_output;
16021626

1603-
let task = SpawnedTask::spawn(Self::pull_and_route(
1604-
stream,
1605-
n,
1606-
senders_clone,
1607-
));
1627+
let task =
1628+
SpawnedTask::spawn(Self::pull_and_route(stream, n, senders_clone));
16081629
tasks.push(task);
16091630
}
16101631

@@ -1615,8 +1636,7 @@ impl ExecutionPlan for AggregateExec {
16151636
}
16161637

16171638
// Take this partition's receiver
1618-
let AggRepartitionState::Consuming { receivers, .. } = &mut *state
1619-
else {
1639+
let AggRepartitionState::Consuming { receivers, .. } = &mut *state else {
16201640
unreachable!()
16211641
};
16221642

@@ -1625,12 +1645,17 @@ impl ExecutionPlan for AggregateExec {
16251645
.and_then(|r| r.take())
16261646
.expect("partition receiver already consumed");
16271647
let schema = Arc::clone(&self.schema);
1648+
// Keep a reference to the shared state so spawned tasks aren't
1649+
// aborted when the AggregateExec is dropped but streams are
1650+
// still being polled.
1651+
let _state_ref = Arc::clone(&self.repartition_state);
16281652
drop(state);
16291653

16301654
// Wrap the receiver as a RecordBatchStream
1631-
let stream = futures::stream::unfold(rx, |mut rx| async move {
1632-
rx.recv().await.map(|batch| (batch, rx))
1633-
});
1655+
let stream =
1656+
futures::stream::unfold((rx, _state_ref), |(mut rx, state_ref)| async move {
1657+
rx.recv().await.map(|batch| (batch, (rx, state_ref)))
1658+
});
16341659
Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
16351660
}
16361661

@@ -4348,9 +4373,7 @@ mod tests {
43484373
Arc::clone(&schema),
43494374
vec![
43504375
Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3])),
4351-
Arc::new(Float64Array::from(vec![
4352-
10.0, 20.0, 30.0, 40.0, 50.0, 60.0,
4353-
])),
4376+
Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])),
43544377
],
43554378
)?;
43564379
let batch2 = RecordBatch::try_new(
@@ -4364,25 +4387,17 @@ mod tests {
43644387

43654388
let task_ctx = Arc::new(TaskContext::default());
43664389

4367-
let group_expr =
4368-
vec![(col("group_col", &schema)?, "group_col".to_string())];
4390+
let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())];
43694391
let aggr_expr = vec![Arc::new(
4370-
AggregateExprBuilder::new(
4371-
sum_udaf(),
4372-
vec![col("value_col", &schema)?],
4373-
)
4374-
.schema(Arc::clone(&schema))
4375-
.alias("sum_value")
4376-
.build()?,
4392+
AggregateExprBuilder::new(sum_udaf(), vec![col("value_col", &schema)?])
4393+
.schema(Arc::clone(&schema))
4394+
.alias("sum_value")
4395+
.build()?,
43774396
)];
43784397

43794398
let num_output_partitions = 3;
43804399

4381-
let exec = TestMemoryExec::try_new(
4382-
&input_partitions,
4383-
Arc::clone(&schema),
4384-
None,
4385-
)?;
4400+
let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
43864401
let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec)));
43874402

43884403
let aggregate_exec = Arc::new(
@@ -4399,7 +4414,10 @@ mod tests {
43994414

44004415
// Verify output partitioning is Hash with num_output_partitions
44014416
assert_eq!(
4402-
aggregate_exec.properties().output_partitioning().partition_count(),
4417+
aggregate_exec
4418+
.properties()
4419+
.output_partitioning()
4420+
.partition_count(),
44034421
num_output_partitions
44044422
);
44054423

@@ -4408,8 +4426,7 @@ mod tests {
44084426
std::collections::HashMap::new();
44094427

44104428
for partition in 0..num_output_partitions {
4411-
let stream =
4412-
aggregate_exec.execute(partition, Arc::clone(&task_ctx))?;
4429+
let stream = aggregate_exec.execute(partition, Arc::clone(&task_ctx))?;
44134430
let batches: Vec<_> = stream
44144431
.collect::<Vec<_>>()
44154432
.await
@@ -4428,8 +4445,7 @@ mod tests {
44284445
.downcast_ref::<Float64Array>()
44294446
.unwrap();
44304447
for i in 0..batch.num_rows() {
4431-
*all_results.entry(groups.value(i)).or_insert(0.0) +=
4432-
sums.value(i);
4448+
*all_results.entry(groups.value(i)).or_insert(0.0) += sums.value(i);
44334449
}
44344450
}
44354451
}

0 commit comments

Comments (0)