@@ -634,8 +634,7 @@ enum AggRepartitionState {
634634 /// Input tasks have been spawned and are sending to output channels.
635635 Consuming {
636636 /// One receiver per output partition. `None` if already taken by `execute()`.
637- receivers :
638- Vec < Option < tokio:: sync:: mpsc:: UnboundedReceiver < Result < RecordBatch > > > > ,
637+ receivers : Vec < Option < tokio:: sync:: mpsc:: UnboundedReceiver < Result < RecordBatch > > > > ,
639638 /// Background tasks; dropped to abort if the exec is dropped.
640639 _abort_helper : Vec < SpawnedTask < ( ) > > ,
641640 } ,
@@ -745,7 +744,7 @@ impl AggregateExec {
745744
746745 /// Clone this exec, overriding only the limit hint.
747746 pub fn with_new_limit_options ( & self , limit_options : Option < LimitOptions > ) -> Self {
748- Self {
747+ let mut new = Self {
749748 limit_options,
750749 // clone the rest of the fields
751750 required_input_ordering : self . required_input_ordering . clone ( ) ,
@@ -762,7 +761,9 @@ impl AggregateExec {
762761 dynamic_filter : self . dynamic_filter . clone ( ) ,
763762 num_agg_partitions : self . num_agg_partitions ,
764763 repartition_state : Arc :: new ( Mutex :: new ( AggRepartitionState :: default ( ) ) ) ,
765- }
764+ } ;
765+ new. update_cache_partitioning ( ) ;
766+ new
766767 }
767768
768769 pub fn cache ( & self ) -> & PlanProperties {
@@ -910,27 +911,55 @@ impl AggregateExec {
910911 & self . mode
911912 }
912913
914+ /// Number of internal hash table partitions.
915+ pub fn num_agg_partitions ( & self ) -> usize {
916+ self . num_agg_partitions
917+ }
918+
913919 /// Set the number of internal hash table partitions for partial aggregation.
914920 /// When n > 1, the aggregate also acts as a repartitioning operator,
915921 /// producing n output partitions.
916922 pub fn with_num_agg_partitions ( mut self , n : usize ) -> Self {
917923 self . num_agg_partitions = n;
918- if n > 1 {
919- // Update output partitioning to reflect the repartitioned output.
920- let group_exprs = self . output_group_expr ( ) ;
921- let mut cache = ( * self . cache ) . clone ( ) ;
922- cache. partitioning = Partitioning :: Hash ( group_exprs, n) ;
923- self . cache = Arc :: new ( cache) ;
924- }
924+ self . update_cache_partitioning ( ) ;
925925 self
926926 }
927927
928928 /// Set the limit options for this AggExec
929929 pub fn with_limit_options ( mut self , limit_options : Option < LimitOptions > ) -> Self {
930930 self . limit_options = limit_options;
931+ self . update_cache_partitioning ( ) ;
931932 self
932933 }
933934
935+ /// Returns true if the channel-based multi-output path will be used
936+ /// in execute().
937+ fn use_channels ( & self ) -> bool {
938+ self . num_agg_partitions > 1
939+ && !self . group_by . is_true_no_grouping ( )
940+ && ( self . limit_options . is_none ( )
941+ || self . is_unordered_unfiltered_group_by_distinct ( ) )
942+ }
943+
944+ /// Update the cached output partitioning based on whether channels
945+ /// will be used.
946+ fn update_cache_partitioning ( & mut self ) {
947+ if self . use_channels ( ) {
948+ let group_exprs = self . output_group_expr ( ) ;
949+ let mut cache = ( * self . cache ) . clone ( ) ;
950+ cache. partitioning = Partitioning :: Hash ( group_exprs, self . num_agg_partitions ) ;
951+ self . cache = Arc :: new ( cache) ;
952+ } else if self . num_agg_partitions > 1 {
953+ // Channels won't be used (e.g. limit/TopK was set), revert
954+ // to single-partition output matching the non-channel path.
955+ let mut cache = ( * self . cache ) . clone ( ) ;
956+ cache. partitioning = Partitioning :: UnknownPartitioning (
957+ self . input . output_partitioning ( ) . partition_count ( ) ,
958+ ) ;
959+ self . cache = Arc :: new ( cache) ;
960+ }
961+ }
962+
934963 /// Get the limit options (if set)
935964 pub fn limit_options ( & self ) -> Option < LimitOptions > {
936965 self . limit_options
@@ -1028,8 +1057,7 @@ impl AggregateExec {
10281057 Err ( e) => {
10291058 let e = Arc :: new ( e) ;
10301059 for sender in & senders {
1031- let _ =
1032- sender. send ( Err ( DataFusionError :: from ( & e) ) ) ;
1060+ let _ = sender. send ( Err ( DataFusionError :: from ( & e) ) ) ;
10331061 }
10341062 return ;
10351063 }
@@ -1553,8 +1581,7 @@ impl ExecutionPlan for AggregateExec {
15531581 if self . num_agg_partitions > 1 {
15541582 let group_exprs = me. output_group_expr ( ) ;
15551583 let mut cache = ( * me. cache ) . clone ( ) ;
1556- cache. partitioning =
1557- Partitioning :: Hash ( group_exprs, self . num_agg_partitions ) ;
1584+ cache. partitioning = Partitioning :: Hash ( group_exprs, self . num_agg_partitions ) ;
15581585 me. cache = Arc :: new ( cache) ;
15591586 }
15601587
@@ -1566,7 +1593,7 @@ impl ExecutionPlan for AggregateExec {
15661593 partition : usize ,
15671594 context : Arc < TaskContext > ,
15681595 ) -> Result < SendableRecordBatchStream > {
1569- if self . num_agg_partitions <= 1 {
1596+ if ! self . use_channels ( ) {
15701597 return self
15711598 . execute_typed ( partition, & context)
15721599 . map ( |stream| stream. into ( ) ) ;
@@ -1593,18 +1620,12 @@ impl ExecutionPlan for AggregateExec {
15931620 // Spawn one task per input partition
15941621 let mut tasks = Vec :: with_capacity ( num_input) ;
15951622 for input_idx in 0 ..num_input {
1596- let stream = GroupedHashAggregateStream :: new (
1597- self , & context, input_idx,
1598- ) ?;
1599- let senders_clone: Vec < _ > =
1600- senders. iter ( ) . map ( |s| s. clone ( ) ) . collect ( ) ;
1623+ let stream = GroupedHashAggregateStream :: new ( self , & context, input_idx) ?;
1624+ let senders_clone: Vec < _ > = senders. iter ( ) . map ( |s| s. clone ( ) ) . collect ( ) ;
16011625 let n = num_output;
16021626
1603- let task = SpawnedTask :: spawn ( Self :: pull_and_route (
1604- stream,
1605- n,
1606- senders_clone,
1607- ) ) ;
1627+ let task =
1628+ SpawnedTask :: spawn ( Self :: pull_and_route ( stream, n, senders_clone) ) ;
16081629 tasks. push ( task) ;
16091630 }
16101631
@@ -1615,8 +1636,7 @@ impl ExecutionPlan for AggregateExec {
16151636 }
16161637
16171638 // Take this partition's receiver
1618- let AggRepartitionState :: Consuming { receivers, .. } = & mut * state
1619- else {
1639+ let AggRepartitionState :: Consuming { receivers, .. } = & mut * state else {
16201640 unreachable ! ( )
16211641 } ;
16221642
@@ -1625,12 +1645,17 @@ impl ExecutionPlan for AggregateExec {
16251645 . and_then ( |r| r. take ( ) )
16261646 . expect ( "partition receiver already consumed" ) ;
16271647 let schema = Arc :: clone ( & self . schema ) ;
1648+ // Keep a reference to the shared state so spawned tasks aren't
1649+ // aborted when the AggregateExec is dropped but streams are
1650+ // still being polled.
1651+ let _state_ref = Arc :: clone ( & self . repartition_state ) ;
16281652 drop ( state) ;
16291653
16301654 // Wrap the receiver as a RecordBatchStream
1631- let stream = futures:: stream:: unfold ( rx, |mut rx| async move {
1632- rx. recv ( ) . await . map ( |batch| ( batch, rx) )
1633- } ) ;
1655+ let stream =
1656+ futures:: stream:: unfold ( ( rx, _state_ref) , |( mut rx, state_ref) | async move {
1657+ rx. recv ( ) . await . map ( |batch| ( batch, ( rx, state_ref) ) )
1658+ } ) ;
16341659 Ok ( Box :: pin ( RecordBatchStreamAdapter :: new ( schema, stream) ) )
16351660 }
16361661
@@ -4348,9 +4373,7 @@ mod tests {
43484373 Arc :: clone ( & schema) ,
43494374 vec ! [
43504375 Arc :: new( Int32Array :: from( vec![ 1 , 2 , 3 , 1 , 2 , 3 ] ) ) ,
4351- Arc :: new( Float64Array :: from( vec![
4352- 10.0 , 20.0 , 30.0 , 40.0 , 50.0 , 60.0 ,
4353- ] ) ) ,
4376+ Arc :: new( Float64Array :: from( vec![ 10.0 , 20.0 , 30.0 , 40.0 , 50.0 , 60.0 ] ) ) ,
43544377 ] ,
43554378 ) ?;
43564379 let batch2 = RecordBatch :: try_new (
@@ -4364,25 +4387,17 @@ mod tests {
43644387
43654388 let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
43664389
4367- let group_expr =
4368- vec ! [ ( col( "group_col" , & schema) ?, "group_col" . to_string( ) ) ] ;
4390+ let group_expr = vec ! [ ( col( "group_col" , & schema) ?, "group_col" . to_string( ) ) ] ;
43694391 let aggr_expr = vec ! [ Arc :: new(
4370- AggregateExprBuilder :: new(
4371- sum_udaf( ) ,
4372- vec![ col( "value_col" , & schema) ?] ,
4373- )
4374- . schema( Arc :: clone( & schema) )
4375- . alias( "sum_value" )
4376- . build( ) ?,
4392+ AggregateExprBuilder :: new( sum_udaf( ) , vec![ col( "value_col" , & schema) ?] )
4393+ . schema( Arc :: clone( & schema) )
4394+ . alias( "sum_value" )
4395+ . build( ) ?,
43774396 ) ] ;
43784397
43794398 let num_output_partitions = 3 ;
43804399
4381- let exec = TestMemoryExec :: try_new (
4382- & input_partitions,
4383- Arc :: clone ( & schema) ,
4384- None ,
4385- ) ?;
4400+ let exec = TestMemoryExec :: try_new ( & input_partitions, Arc :: clone ( & schema) , None ) ?;
43864401 let exec = Arc :: new ( TestMemoryExec :: update_cache ( & Arc :: new ( exec) ) ) ;
43874402
43884403 let aggregate_exec = Arc :: new (
@@ -4399,7 +4414,10 @@ mod tests {
43994414
44004415 // Verify output partitioning is Hash with num_output_partitions
44014416 assert_eq ! (
4402- aggregate_exec. properties( ) . output_partitioning( ) . partition_count( ) ,
4417+ aggregate_exec
4418+ . properties( )
4419+ . output_partitioning( )
4420+ . partition_count( ) ,
44034421 num_output_partitions
44044422 ) ;
44054423
@@ -4408,8 +4426,7 @@ mod tests {
44084426 std:: collections:: HashMap :: new ( ) ;
44094427
44104428 for partition in 0 ..num_output_partitions {
4411- let stream =
4412- aggregate_exec. execute ( partition, Arc :: clone ( & task_ctx) ) ?;
4429+ let stream = aggregate_exec. execute ( partition, Arc :: clone ( & task_ctx) ) ?;
44134430 let batches: Vec < _ > = stream
44144431 . collect :: < Vec < _ > > ( )
44154432 . await
@@ -4428,8 +4445,7 @@ mod tests {
44284445 . downcast_ref :: < Float64Array > ( )
44294446 . unwrap ( ) ;
44304447 for i in 0 ..batch. num_rows ( ) {
4431- * all_results. entry ( groups. value ( i) ) . or_insert ( 0.0 ) +=
4432- sums. value ( i) ;
4448+ * all_results. entry ( groups. value ( i) ) . or_insert ( 0.0 ) += sums. value ( i) ;
44334449 }
44344450 }
44354451 }
0 commit comments