From 40463ea8e0f7d1960e6d26ba985c2eda37048b35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 6 Mar 2026 20:02:53 +0100 Subject: [PATCH 1/6] Test early emitting --- .../physical-plan/src/aggregates/row_hash.rs | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 4a1b0e5c8c027..acb8ac7afaa59 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -776,6 +776,25 @@ impl Stream for GroupedHashAggregateStream { self.exec_state = new_state; break 'reading_input; } + + // Emit partial aggregation results once accumulated + // state exceeds 4MB to bound memory usage and provide + // incremental output to downstream operators. + if self.current_aggregate_size() + >= Self::PARTIAL_AGGREGATE_EMIT_SIZE + && !self.group_values.is_empty() + { + timer.done(); + if let Some(batch) = self.emit(EmitTo::All, false)? { + // Clear the hash map so new groups can be + // interned correctly when we resume reading. + let batch_size = self.batch_size; + self.clear_shrink(batch_size); + self.exec_state = + ExecutionState::ProducingOutput(batch); + } + break 'reading_input; + } } // If we reach this point, try to update the memory reservation @@ -1210,6 +1229,18 @@ impl GroupedHashAggregateStream { self.clear_shrink(0); } + /// Size threshold (in bytes) at which partial aggregation emits accumulated + /// state. Emitting periodically bounds memory usage and provides incremental + /// output to downstream operators (e.g. repartition / final aggregation). + const PARTIAL_AGGREGATE_EMIT_SIZE: usize = 4 * 1024 * 1024; // 4 MB + + /// Returns the current memory size of accumulated group state + /// (group values + accumulators). 
+ fn current_aggregate_size(&self) -> usize { + let acc: usize = self.accumulators.iter().map(|a| a.size()).sum(); + acc + self.group_values.size() + } + /// returns true if there is a soft groups limit and the number of distinct /// groups we have seen is over that limit fn hit_soft_group_limit(&self) -> bool { From b793a938a053788e545899c4676981497c0ebab8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 6 Mar 2026 21:36:15 +0100 Subject: [PATCH 2/6] Add per-partition hash tables for partial aggregation Introduce PartitionAggState to support multiple internal hash tables in partial aggregation. When enabled via AggregateExec::with_num_agg_partitions(), input rows are hashed by group keys (using the same hash as RepartitionExec) and routed to separate smaller hash tables for better cache locality. Defaults to 1 partition (no behavior change). The optimizer can set higher values when a hash repartition follows the partial aggregate. Co-Authored-By: Claude Opus 4.6 --- .../physical-plan/src/aggregates/mod.rs | 15 + .../physical-plan/src/aggregates/row_hash.rs | 696 +++++++++++++----- 2 files changed, 533 insertions(+), 178 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 561e9d1d05e89..84546e0407ecb 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -660,6 +660,12 @@ pub struct AggregateExec { /// it remains `Some(..)` to enable dynamic filtering during aggregate execution; /// otherwise, it is cleared to `None`. dynamic_filter: Option>, + + /// Number of internal hash table partitions for partial aggregation. + /// When > 1, input rows are hashed by group keys and routed to separate + /// smaller hash tables for better cache locality. Only used when mode is + /// Partial and input is unordered. Defaults to 1 (single hash table). 
+ pub(crate) num_agg_partitions: usize, } impl AggregateExec { @@ -685,6 +691,7 @@ impl AggregateExec { schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), dynamic_filter: self.dynamic_filter.clone(), + num_agg_partitions: self.num_agg_partitions, } } @@ -705,6 +712,7 @@ impl AggregateExec { schema: Arc::clone(&self.schema), input_schema: Arc::clone(&self.input_schema), dynamic_filter: self.dynamic_filter.clone(), + num_agg_partitions: self.num_agg_partitions, } } @@ -839,6 +847,7 @@ impl AggregateExec { input_order_mode, cache: Arc::new(cache), dynamic_filter: None, + num_agg_partitions: 1, }; exec.init_dynamic_filter(); @@ -851,6 +860,12 @@ impl AggregateExec { &self.mode } + /// Set the number of internal hash table partitions for partial aggregation. + pub fn with_num_agg_partitions(mut self, n: usize) -> Self { + self.num_agg_partitions = n; + self + } + /// Set the limit options for this AggExec pub fn with_limit_options(mut self, limit_options: Option) -> Self { self.limit_options = limit_options; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index acb8ac7afaa59..56dbcf5d882b8 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -50,7 +50,10 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; +use crate::repartition::REPARTITION_RANDOM_STATE; use crate::sorts::IncrementalSortIterator; +use arrow::compute::take_arrays; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::instant::Instant; use datafusion_common::utils::memory::get_record_batch_memory_size; use futures::ready; @@ -221,6 +224,16 @@ enum OutOfMemoryMode { ReportError, } +/// Per-partition aggregation state for partitioned hash aggregation. 
+/// When the partial aggregate uses multiple partitions, input rows are +/// hashed by group keys and routed to the appropriate partition's hash +/// table for better cache locality. +struct PartitionAggState { + group_values: Box<dyn GroupValues>, + accumulators: Vec<Box<dyn GroupsAccumulator>>, + current_group_indices: Vec<usize>, +} + /// HashTable based Grouping Aggregator /// /// # Design Goals @@ -407,19 +420,19 @@ pub(crate) struct GroupedHashAggregateStream { // STATE BUFFERS: // These fields will accumulate intermediate results during the execution. // ======================================================================== - /// An interning store of group keys - group_values: Box<dyn GroupValues>, + /// Per-partition aggregation state. When `partitions.len() > 1`, + /// input rows are hashed by group keys and routed to the + /// appropriate partition for better cache locality. + partitions: Vec<PartitionAggState>, - /// scratch space for the current input [`RecordBatch`] being - /// processed. Reused across batches here to avoid reallocations - current_group_indices: Vec<usize>, + /// Scratch buffer for hashing group keys to determine partition assignment. + hash_buffer: Vec<u64>, - /// Accumulators, one for each `AggregateFunctionExpr` in the query - /// - /// For example, if the query has aggregates, `SUM(x)`, - /// `COUNT(y)`, there will be two accumulators, each one - /// specialized for that particular aggregate and its input types - accumulators: Vec<Box<dyn GroupsAccumulator>>, + /// Per-partition row indices used during input routing. + partition_indices: Vec<Vec<u32>>, + + /// Tracks which partition to emit next during final output drain. 
+ emit_partition_idx: usize, // ======================================================================== // TASK-SPECIFIC STATES: @@ -497,12 +510,6 @@ impl GroupedHashAggregateStream { AggregateInputMode::Partial => vec![None; agg.aggr_expr.len()].into(), }; - // Instantiate the accumulators - let accumulators: Vec<_> = aggregate_exprs - .iter() - .map(create_group_accumulator) - .collect::<Result<_>>()?; - let group_schema = agg_group_by.group_schema(&agg.input().schema())?; // fix https://github.com/apache/datafusion/issues/13949 @@ -587,7 +594,6 @@ impl GroupedHashAggregateStream { _ => OutOfMemoryMode::ReportError, }; - let group_values = new_group_values(group_schema, &group_ordering)?; let reservation = MemoryConsumer::new(name) // We interpret 'can spill' as 'can handle memory back pressure'. // This value needs to be set to true for the default memory pool implementations @@ -617,6 +623,32 @@ impl GroupedHashAggregateStream { spill_manager, }; + // Determine the number of internal aggregation partitions. + // When > 1, input rows are hashed by group keys and routed to + // separate smaller hash tables for better cache locality. + // Currently defaults to 1; set via AggregateExec::num_agg_partitions. 
+ let num_agg_partitions = if agg.mode == AggregateMode::Partial + && matches!(group_ordering, GroupOrdering::None) + { + agg.num_agg_partitions.max(1) + } else { + 1 + }; + + let mut partitions = Vec::with_capacity(num_agg_partitions); + for _ in 0..num_agg_partitions { + let gv = new_group_values(Arc::clone(&group_schema), &group_ordering)?; + let accs: Vec<_> = aggregate_exprs + .iter() + .map(create_group_accumulator) + .collect::>()?; + partitions.push(PartitionAggState { + group_values: gv, + accumulators: accs, + current_group_indices: Default::default(), + }); + } + // Skip aggregation is supported if: // - aggregation mode is Partial // - input is not ordered by GROUP BY expressions, @@ -626,7 +658,8 @@ impl GroupedHashAggregateStream { // - there is only one GROUP BY expressions set let skip_aggregation_probe = if agg.mode == AggregateMode::Partial && matches!(group_ordering, GroupOrdering::None) - && accumulators + && partitions[0] + .accumulators .iter() .all(|acc| acc.supports_convert_to_state()) && agg_group_by.is_single() @@ -661,14 +694,19 @@ impl GroupedHashAggregateStream { schema: agg_schema, input, mode: agg.mode, - accumulators, + partitions, + hash_buffer: Vec::new(), + partition_indices: if num_agg_partitions > 1 { + vec![vec![]; num_agg_partitions] + } else { + Vec::new() + }, + emit_partition_idx: num_agg_partitions, aggregate_arguments, filter_expressions, group_by: agg_group_by, reservation, oom_mode, - group_values, - current_group_indices: Default::default(), exec_state, baseline_metrics, group_by_metrics, @@ -747,13 +785,15 @@ impl Stream for GroupedHashAggregateStream { // Try to emit completed groups if possible. // If we already started spilling, we can no longer emit since - // this might lead to incorrect output ordering - if (self.spill_state.spills.is_empty() - || self.spill_state.is_stream_merging) + // this might lead to incorrect output ordering. + // Group ordering only applies to single-partition mode. 
+ if self.partitions.len() == 1 + && (self.spill_state.spills.is_empty() + || self.spill_state.is_stream_merging) && let Some(to_emit) = self.group_ordering.emit_to() { timer.done(); - if let Some(batch) = self.emit(to_emit, false)? { + if let Some(batch) = self.emit(0, to_emit, false)? { self.exec_state = ExecutionState::ProducingOutput(batch); }; @@ -777,24 +817,6 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - // Emit partial aggregation results once accumulated - // state exceeds 4MB to bound memory usage and provide - // incremental output to downstream operators. - if self.current_aggregate_size() - >= Self::PARTIAL_AGGREGATE_EMIT_SIZE - && !self.group_values.is_empty() - { - timer.done(); - if let Some(batch) = self.emit(EmitTo::All, false)? { - // Clear the hash map so new groups can be - // interned correctly when we resume reading. - let batch_size = self.batch_size; - self.clear_shrink(batch_size); - self.exec_state = - ExecutionState::ProducingOutput(batch); - } - break 'reading_input; - } } // If we reach this point, try to update the memory reservation @@ -844,11 +866,11 @@ impl Stream for GroupedHashAggregateStream { // inner is done, switching to `Done` state // Sanity check: when switching from SkippingAggregation to Done, // all groups should have already been emitted - if !self.group_values.is_empty() { + if !self.all_partitions_empty() { return Poll::Ready(Some(internal_err!( "Switching from SkippingAggregation to Done with {} groups still in hash table. \ This is a bug - all groups should have been emitted before skip aggregation started.", - self.group_values.len() + self.total_group_len() ))); } self.exec_state = ExecutionState::Done; @@ -856,13 +878,25 @@ impl Stream for GroupedHashAggregateStream { } } - ExecutionState::ProducingOutput(batch) => { + ExecutionState::ProducingOutput(_) => { + // Take ownership of the batch to allow mutable self access below. 
+ let batch = std::mem::replace( + &mut self.exec_state, + ExecutionState::Done, + ); + let ExecutionState::ProducingOutput(batch) = batch else { + unreachable!() + }; // slice off a part of the batch, if needed let output_batch; let size = self.batch_size; (self.exec_state, output_batch) = if batch.num_rows() <= size { - ( - if self.input_done { + // If there are remaining partitions to drain, emit them + // before transitioning to the next state. + let next_state = + if self.emit_partition_idx < self.partitions.len() { + self.emit_next_partition()? + } else if self.input_done { ExecutionState::Done } // In Partial aggregation, we also need to check @@ -873,9 +907,8 @@ impl Stream for GroupedHashAggregateStream { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput - }, - batch.clone(), - ) + }; + (next_state, batch) } else { // output first batch_size rows let size = self.batch_size; @@ -899,11 +932,11 @@ impl Stream for GroupedHashAggregateStream { ExecutionState::Done => { // Sanity check: all groups should have been emitted by now - if !self.group_values.is_empty() { + if !self.all_partitions_empty() { return Poll::Ready(Some(internal_err!( "AggregateStream was in Done state with {} groups left in hash table. \ This is a bug - all groups should have been emitted before entering Done state.", - self.group_values.len() + self.total_group_len() ))); } // release the memory reservation since sending back output batch itself needs @@ -955,68 +988,127 @@ impl GroupedHashAggregateStream { // Evaluate the filter expressions, if any, against the inputs let filter_values = if self.spill_state.is_stream_merging { - let filter_expressions = vec![None; self.accumulators.len()]; + let filter_expressions = + vec![None; self.partitions[0].accumulators.len()]; evaluate_optional(&filter_expressions, batch)? } else { evaluate_optional(&self.filter_expressions, batch)? 
}; + let is_raw_input = self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging; + for group_values in &group_by_values { let groups_start_time = Instant::now(); - // calculate the group indices for each input row - let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; - let group_indices = &self.current_group_indices; + if self.partitions.len() == 1 { + // Fast path: single partition, no hashing needed + let partition = &mut self.partitions[0]; + let starting_num_groups = partition.group_values.len(); + partition + .group_values + .intern(group_values, &mut partition.current_group_indices)?; + let total_num_groups = partition.group_values.len(); + + if total_num_groups > starting_num_groups { + self.group_ordering.new_groups( + group_values, + &partition.current_group_indices, + total_num_groups, + )?; + } - // Update ordering information if necessary - let total_num_groups = self.group_values.len(); - if total_num_groups > starting_num_groups { - self.group_ordering.new_groups( + let agg_start_time = Instant::now(); + self.group_by_metrics + .time_calculating_group_ids + .add_duration(agg_start_time - groups_start_time); + + Self::accumulate_partition( + partition, + &input_values, + &filter_values, + is_raw_input, + )?; + + self.group_by_metrics + .aggregation_time + .add_elapsed(agg_start_time); + } else { + // Multi-partition path: hash group keys and route rows + let num_rows = group_values[0].len(); + self.hash_buffer.clear(); + self.hash_buffer.resize(num_rows, 0); + create_hashes( group_values, - group_indices, - total_num_groups, + REPARTITION_RANDOM_STATE.random_state(), + &mut self.hash_buffer, )?; - } - // Use this instant for both measurements to save a syscall - let agg_start_time = Instant::now(); - self.group_by_metrics - .time_calculating_group_ids - .add_duration(agg_start_time - groups_start_time); + let num_partitions = 
self.partitions.len(); + self.partition_indices + .iter_mut() + .for_each(|v| v.clear()); + for (row_idx, hash) in self.hash_buffer.iter().enumerate() { + self.partition_indices + [(*hash % num_partitions as u64) as usize] + .push(row_idx as u32); + } - // Gather the inputs to call the actual accumulator - let t = self - .accumulators - .iter_mut() - .zip(input_values.iter()) - .zip(filter_values.iter()); - - for ((acc, values), opt_filter) in t { - let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); - - // Call the appropriate method on each aggregator with - // the entire input row and the relevant group indexes - if self.mode.input_mode() == AggregateInputMode::Raw - && !self.spill_state.is_stream_merging - { - acc.update_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; - } else { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" + let agg_start_time = Instant::now(); + self.group_by_metrics + .time_calculating_group_ids + .add_duration(agg_start_time - groups_start_time); + + for p in 0..num_partitions { + if self.partition_indices[p].is_empty() { + continue; + } + let indices_array = UInt32Array::from_iter_values( + self.partition_indices[p].iter().copied(), ); - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; + // Take sub-batch for this partition + let part_groups = + take_arrays(group_values, &indices_array, None)?; + let part_inputs: Vec> = input_values + .iter() + .map(|vals| { + take_arrays(vals.as_slice(), &indices_array, None) + .map_err(DataFusionError::from) + }) + .collect::>()?; + let part_filters: Vec> = filter_values + .iter() + .map(|opt_f| { + opt_f + .as_ref() + .map(|f| { + take_arrays( + std::slice::from_ref(f), + &indices_array, + None, + ) + .map(|mut v| v.remove(0)) + .map_err(DataFusionError::from) + }) + 
.transpose() + }) + .collect::>()?; + + let partition = &mut self.partitions[p]; + partition.group_values.intern( + &part_groups, + &mut partition.current_group_indices, + )?; + + Self::accumulate_partition( + partition, + &part_inputs, + &part_filters, + is_raw_input, + )?; } + self.group_by_metrics .aggregation_time .add_elapsed(agg_start_time); @@ -1026,6 +1118,48 @@ impl GroupedHashAggregateStream { Ok(()) } + /// Update accumulators for a single partition. + fn accumulate_partition( + partition: &mut PartitionAggState, + input_values: &[Vec], + filter_values: &[Option], + is_raw_input: bool, + ) -> Result<()> { + let group_indices = &partition.current_group_indices; + let total_num_groups = partition.group_values.len(); + + let t = partition + .accumulators + .iter_mut() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + for ((acc, values), opt_filter) in t { + let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); + + if is_raw_input { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + acc.merge_batch( + values, + group_indices, + None, + total_num_groups, + )?; + } + } + Ok(()) + } + /// Attempts to update the memory reservation. If that fails due to a /// [DataFusionError::ResourcesExhausted] error, an attempt will be made to resolve /// the out-of-memory condition based on the [out-of-memory handling mode](OutOfMemoryMode). 
@@ -1041,37 +1175,52 @@ impl GroupedHashAggregateStream { }; match self.oom_mode { - OutOfMemoryMode::Spill if !self.group_values.is_empty() => { + OutOfMemoryMode::Spill if !self.all_partitions_empty() => { self.spill()?; self.clear_shrink(self.batch_size); self.update_memory_reservation()?; Ok(None) } - OutOfMemoryMode::EmitEarly if self.group_values.len() > 1 => { - let n = if self.group_values.len() >= self.batch_size { - // Try to emit an integer multiple of batch size if possible - self.group_values.len() / self.batch_size * self.batch_size - } else { - // Otherwise emit whatever we can - self.group_values.len() - }; - - // Clamp to the sort boundary when using partial group ordering, - // otherwise remove_groups panics (#20445). - let n = match &self.group_ordering { - GroupOrdering::None => n, - _ => match self.group_ordering.emit_to() { - Some(EmitTo::First(max)) => n.min(max), - _ => 0, - }, - }; - - if n > 0 - && let Some(batch) = self.emit(EmitTo::First(n), false)? - { - Ok(Some(ExecutionState::ProducingOutput(batch))) + OutOfMemoryMode::EmitEarly if self.total_group_len() > 1 => { + if self.partitions.len() == 1 { + let n = if self.partitions[0].group_values.len() + >= self.batch_size + { + self.partitions[0].group_values.len() / self.batch_size + * self.batch_size + } else { + self.partitions[0].group_values.len() + }; + + // Clamp to the sort boundary when using partial group ordering, + // otherwise remove_groups panics (#20445). + let n = match &self.group_ordering { + GroupOrdering::None => n, + _ => match self.group_ordering.emit_to() { + Some(EmitTo::First(max)) => n.min(max), + _ => 0, + }, + }; + + if n > 0 + && let Some(batch) = + self.emit(0, EmitTo::First(n), false)? + { + Ok(Some(ExecutionState::ProducingOutput(batch))) + } else { + Err(oom) + } } else { - Err(oom) + // Multi-partition: emit the largest partition + let largest = self.largest_partition_idx(); + if let Some(batch) = + self.emit(largest, EmitTo::All, false)? 
+ { + self.clear_partition(largest, self.batch_size); + Ok(Some(ExecutionState::ProducingOutput(batch))) + } else { + Err(oom) + } } } _ => Err(oom), @@ -1079,26 +1228,25 @@ impl GroupedHashAggregateStream { } fn update_memory_reservation(&mut self) -> Result<()> { - let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - let groups_and_acc_size = acc - + self.group_values.size() - + self.group_ordering.size() - + self.current_group_indices.allocated_size(); + let mut acc: usize = 0; + let mut groups_size: usize = 0; + let mut indices_size: usize = 0; + for p in &self.partitions { + acc += p.accumulators.iter().map(|x| x.size()).sum::(); + groups_size += p.group_values.size(); + indices_size += p.current_group_indices.allocated_size(); + } + let groups_and_acc_size = + acc + groups_size + self.group_ordering.size() + indices_size; // Reserve extra headroom for sorting during potential spill. - // When OOM triggers, group_aggregate_batch has already processed the - // latest input batch, so the internal state may have grown well beyond - // the last successful reservation. The emit batch reflects this larger - // actual state, and the sort needs memory proportional to it. - // By reserving headroom equal to the data size, we trigger OOM earlier - // (before too much data accumulates), ensuring the freed reservation - // after clear_shrink is sufficient to cover the sort memory. 
- let sort_headroom = - if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() { - acc + self.group_values.size() - } else { - 0 - }; + let sort_headroom = if self.oom_mode == OutOfMemoryMode::Spill + && !self.all_partitions_empty() + { + acc + groups_size + } else { + 0 + }; let new_size = groups_and_acc_size + sort_headroom; let reservation_result = self.reservation.try_resize(new_size); @@ -1114,36 +1262,45 @@ impl GroupedHashAggregateStream { /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to - fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result> { + fn emit( + &mut self, + partition_idx: usize, + emit_to: EmitTo, + spilling: bool, + ) -> Result> { let schema = if spilling { Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; - if self.group_values.is_empty() { + let mode = self.mode; + let num_partitions = self.partitions.len(); + + if self.partitions[partition_idx].group_values.is_empty() { return Ok(None); } let timer = self.group_by_metrics.emitting_time.timer(); - let mut output = self.group_values.emit(emit_to)?; + let partition = &mut self.partitions[partition_idx]; + let mut output = partition.group_values.emit(emit_to)?; + if let EmitTo::First(n) = emit_to { - self.group_ordering.remove_groups(n); + if num_partitions == 1 { + self.group_ordering.remove_groups(n); + } } // Next output each aggregate value - for acc in self.accumulators.iter_mut() { - if self.mode.output_mode() == AggregateOutputMode::Final && !spilling { + for acc in partition.accumulators.iter_mut() { + if mode.output_mode() == AggregateOutputMode::Final && !spilling { output.push(acc.evaluate(emit_to)?) } else { - // Output partial state: either because we're in a non-final mode, - // or because we're spilling and will merge/re-evaluate later. output.extend(acc.state(emit_to)?) } } drop(timer); - // emit reduces the memory usage. Ignore Err from update_memory_reservation. 
Even if it is - // over the target memory size after emission, we can emit again rather than returning Err. + // emit reduces the memory usage. Ignore Err from update_memory_reservation. let _ = self.update_memory_reservation(); let batch = RecordBatch::try_new(schema, output)?; debug_assert!(batch.num_rows() > 0); @@ -1155,8 +1312,8 @@ impl GroupedHashAggregateStream { /// This process helps in reducing memory pressure by allowing the data to be /// read back with streaming merge. fn spill(&mut self) -> Result<()> { - // Emit and sort intermediate aggregation state - let Some(emit) = self.emit(EmitTo::All, true)? else { + // Emit and sort intermediate aggregation state (spill always uses partition 0) + let Some(emit) = self.emit(0, EmitTo::All, true)? else { return Ok(()); }; @@ -1219,9 +1376,19 @@ impl GroupedHashAggregateStream { /// Clear memory and shrink capacities to the given number of rows. fn clear_shrink(&mut self, num_rows: usize) { - self.group_values.clear_shrink(num_rows); - self.current_group_indices.clear(); - self.current_group_indices.shrink_to(num_rows); + for p in &mut self.partitions { + p.group_values.clear_shrink(num_rows); + p.current_group_indices.clear(); + p.current_group_indices.shrink_to(num_rows); + } + } + + /// Clear a single partition's memory and shrink its capacities. + fn clear_partition(&mut self, partition_idx: usize, num_rows: usize) { + let p = &mut self.partitions[partition_idx]; + p.group_values.clear_shrink(num_rows); + p.current_group_indices.clear(); + p.current_group_indices.shrink_to(num_rows); } /// Clear memory and shrink capacities to zero. @@ -1229,16 +1396,43 @@ impl GroupedHashAggregateStream { self.clear_shrink(0); } - /// Size threshold (in bytes) at which partial aggregation emits accumulated - /// state. Emitting periodically bounds memory usage and provides incremental - /// output to downstream operators (e.g. repartition / final aggregation). 
- const PARTIAL_AGGREGATE_EMIT_SIZE: usize = 4 * 1024 * 1024; // 4 MB + /// Total number of groups across all partitions. + fn total_group_len(&self) -> usize { + self.partitions.iter().map(|p| p.group_values.len()).sum() + } + + /// Returns true if all partitions have no groups. + fn all_partitions_empty(&self) -> bool { + self.partitions.iter().all(|p| p.group_values.is_empty()) + } + + /// Returns the index of the partition with the most accumulated data. + fn largest_partition_idx(&self) -> usize { + self.partitions + .iter() + .enumerate() + .max_by_key(|(_, p)| { + p.accumulators + .iter() + .map(|a| a.size()) + .sum::() + + p.group_values.size() + }) + .map(|(i, _)| i) + .unwrap_or(0) + } - /// Returns the current memory size of accumulated group state - /// (group values + accumulators). - fn current_aggregate_size(&self) -> usize { - let acc: usize = self.accumulators.iter().map(|a| a.size()).sum(); - acc + self.group_values.size() + /// Emit the next non-empty partition (used during final output drain). + /// Returns the new execution state. + fn emit_next_partition(&mut self) -> Result { + while self.emit_partition_idx < self.partitions.len() { + let idx = self.emit_partition_idx; + self.emit_partition_idx += 1; + if let Some(batch) = self.emit(idx, EmitTo::All, false)? { + return Ok(ExecutionState::ProducingOutput(batch)); + } + } + Ok(ExecutionState::Done) } /// returns true if there is a soft groups limit and the number of distinct @@ -1247,7 +1441,7 @@ impl GroupedHashAggregateStream { let Some(group_values_soft_limit) = self.group_values_soft_limit else { return false; }; - group_values_soft_limit <= self.group_values.len() + group_values_soft_limit <= self.total_group_len() } /// Finalizes reading of the input stream and prepares for producing output values. 
@@ -1261,12 +1455,15 @@ impl GroupedHashAggregateStream { let timer = elapsed_compute.timer(); self.exec_state = if self.spill_state.spills.is_empty() { // Input has been entirely processed without spilling to disk. - - // Flush any remaining group values. - let batch = self.emit(EmitTo::All, false)?; - - // If there are none, we're done; otherwise switch to emitting them - batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) + if self.partitions.len() == 1 { + // Single partition: flush remaining group values. + let batch = self.emit(0, EmitTo::All, false)?; + batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) + } else { + // Multi-partition: emit partitions one at a time. + self.emit_partition_idx = 0; + self.emit_next_partition()? + } } else { // Spill any remaining data to disk. There is some performance overhead in // writing out this last chunk of data and reading it back. The benefit of @@ -1318,7 +1515,9 @@ impl GroupedHashAggregateStream { // Skip aggregation probe is not supported if stream has any spills, // currently spilling is not supported for Partial aggregation assert!(self.spill_state.spills.is_empty()); - probe.update_state(input_rows, self.group_values.len()); + let total_groups: usize = + self.partitions.iter().map(|p| p.group_values.len()).sum(); + probe.update_state(input_rows, total_groups); }; } @@ -1331,9 +1530,15 @@ impl GroupedHashAggregateStream { fn switch_to_skip_aggregation(&mut self) -> Result> { if let Some(probe) = self.skip_aggregation_probe.as_mut() && probe.should_skip() - && let Some(batch) = self.emit(EmitTo::All, false)? { - return Ok(Some(ExecutionState::ProducingOutput(batch))); + // Emit partitions one at a time. The first non-empty partition + // becomes ProducingOutput; remaining partitions drain via + // emit_next_partition() when ProducingOutput completes. 
+ self.emit_partition_idx = 0; + let state = self.emit_next_partition()?; + if !matches!(state, ExecutionState::Done) { + return Ok(Some(state)); + } }; Ok(None) @@ -1362,7 +1567,7 @@ impl GroupedHashAggregateStream { ); let mut output = group_values.swap_remove(0); - let iter = self + let iter = self.partitions[0] .accumulators .iter() .zip(input_values.iter()) @@ -1729,4 +1934,139 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_partitioned_hash_aggregation() -> Result<()> { + // Test that multi-partition hash aggregation produces correct results. + // With num_agg_partitions > 1, rows are hashed by group key and routed + // to separate hash tables. The final results should be identical to + // single-partition aggregation. + + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Int64, false), + ])); + + // Create data with known group distribution + let num_rows = 1000; + let num_groups = 50; + let group_ids: Vec = + (0..num_rows).map(|i| (i % num_groups) as i32).collect(); + let values: Vec = (0..num_rows).map(|i| i as i64).collect(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(group_ids)), + Arc::new(Int64Array::from(values)), + ], + )?; + + let input_partitions = vec![vec![batch]]; + let task_ctx = Arc::new(TaskContext::default()); + + let group_expr = + vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new( + count_udaf(), + vec![col("value_col", &schema)?], + ) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, + )]; + + // Run with 1 partition (baseline) + let exec = TestMemoryExec::try_new( + &input_partitions, + Arc::clone(&schema), + None, + )?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec_1 = AggregateExec::try_new( + AggregateMode::Partial, + 
PhysicalGroupBy::new_single(group_expr.clone()), + aggr_expr.clone(), + vec![None], + Arc::clone(&exec) as _, + Arc::clone(&schema), + )?; + + let mut stream_1 = GroupedHashAggregateStream::new( + &aggregate_exec_1, + &task_ctx, + 0, + )?; + let mut results_1 = std::collections::HashMap::new(); + while let Some(Ok(batch)) = stream_1.next().await { + let groups = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let counts = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + *results_1.entry(groups.value(i)).or_insert(0i64) += + counts.value(i); + } + } + + // Run with 4 partitions + let exec = TestMemoryExec::try_new( + &input_partitions, + Arc::clone(&schema), + None, + )?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec_4 = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )? + .with_num_agg_partitions(4); + + let mut stream_4 = GroupedHashAggregateStream::new( + &aggregate_exec_4, + &task_ctx, + 0, + )?; + let mut results_4 = std::collections::HashMap::new(); + while let Some(Ok(batch)) = stream_4.next().await { + let groups = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let counts = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + *results_4.entry(groups.value(i)).or_insert(0i64) += + counts.value(i); + } + } + + // Verify results match + assert_eq!(results_1.len(), num_groups as usize); + assert_eq!(results_1, results_4); + + // Each group should have exactly num_rows/num_groups = 20 rows + for (_, count) in &results_1 { + assert_eq!(*count, (num_rows / num_groups) as i64); + } + + Ok(()) + } } From 27fb7c8efd1dbcb9f50a6a0bb89d75a9e10f9bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 6 Mar 2026 22:09:32 +0100 Subject: [PATCH 3/6] Add channel-based 
multi-output to AggregateExec for repartitioned partial aggregation When num_agg_partitions > 1, the partial aggregate now acts as a repartitioning operator, producing T output partitions directly via channels. Each input task runs a GroupedHashAggregateStream with T internal hash tables, then routes emitted batches to the correct output channel using last_emitted_partition (no re-hashing needed since internal tables use the same REPARTITION_RANDOM_STATE hash). Co-Authored-By: Claude Opus 4.6 --- .../physical-plan/src/aggregates/mod.rs | 296 +++++++++++++++++- .../physical-plan/src/aggregates/row_hash.rs | 8 + .../physical-plan/src/repartition/mod.rs | 2 +- 3 files changed, 301 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 84546e0407ecb..32b7477fdda1b 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -31,12 +31,15 @@ use crate::filter_pushdown::{ FilterPushdownPropagation, PushedDownPredicate, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::stream::RecordBatchStreamAdapter; use crate::{ - DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, SendableRecordBatchStream, Statistics, check_if_same_properties, }; use datafusion_common::config::ConfigOptions; +use datafusion_common_runtime::SpawnedTask; use datafusion_physical_expr::utils::collect_columns; +use futures::StreamExt; use parking_lot::Mutex; use std::collections::HashSet; @@ -47,7 +50,8 @@ use arrow_schema::FieldRef; use datafusion_common::stats::Precision; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{ - Constraint, Constraints, Result, ScalarValue, assert_eq_or_internal_err, not_impl_err, + Constraint, Constraints, DataFusionError, Result, ScalarValue, + assert_eq_or_internal_err, not_impl_err, }; 
use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; @@ -621,6 +625,39 @@ impl LimitOptions { } } +/// State for channel-based multi-output aggregation. +/// When `num_agg_partitions > 1`, partial aggregation acts as a repartitioning +/// operator, routing output from N input partitions to T output partitions. +enum AggRepartitionState { + /// Not yet initialized. Transitions on first `execute()` call. + NotInitialized, + /// Input tasks have been spawned and are sending to output channels. + Consuming { + /// One receiver per output partition. `None` if already taken by `execute()`. + receivers: + Vec>>>, + /// Background tasks; dropped to abort if the exec is dropped. + _abort_helper: Vec>, + }, +} + +impl Default for AggRepartitionState { + fn default() -> Self { + Self::NotInitialized + } +} + +impl std::fmt::Debug for AggRepartitionState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::NotInitialized => write!(f, "NotInitialized"), + Self::Consuming { receivers, .. } => { + write!(f, "Consuming({})", receivers.len()) + } + } + } +} + /// Hash aggregate execution plan #[derive(Debug, Clone)] pub struct AggregateExec { @@ -665,7 +702,17 @@ pub struct AggregateExec { /// When > 1, input rows are hashed by group keys and routed to separate /// smaller hash tables for better cache locality. Only used when mode is /// Partial and input is unordered. Defaults to 1 (single hash table). + /// + /// When > 1 AND this is a Partial aggregate, this also acts as a + /// repartitioning operator: the aggregate produces `num_agg_partitions` + /// output partitions (one per internal hash table), eliminating the need + /// for a downstream RepartitionExec. pub(crate) num_agg_partitions: usize, + + /// Shared state for channel-based multi-output when `num_agg_partitions > 1`. + /// Each output partition gets a receiver; input tasks send partitioned + /// batches to the appropriate output channel. 
+ repartition_state: Arc>, } impl AggregateExec { @@ -692,6 +739,7 @@ impl AggregateExec { input_schema: Arc::clone(&self.input_schema), dynamic_filter: self.dynamic_filter.clone(), num_agg_partitions: self.num_agg_partitions, + repartition_state: Arc::new(Mutex::new(AggRepartitionState::default())), } } @@ -713,6 +761,7 @@ impl AggregateExec { input_schema: Arc::clone(&self.input_schema), dynamic_filter: self.dynamic_filter.clone(), num_agg_partitions: self.num_agg_partitions, + repartition_state: Arc::new(Mutex::new(AggRepartitionState::default())), } } @@ -848,6 +897,7 @@ impl AggregateExec { cache: Arc::new(cache), dynamic_filter: None, num_agg_partitions: 1, + repartition_state: Arc::new(Mutex::new(AggRepartitionState::default())), }; exec.init_dynamic_filter(); @@ -861,8 +911,17 @@ impl AggregateExec { } /// Set the number of internal hash table partitions for partial aggregation. + /// When n > 1, the aggregate also acts as a repartitioning operator, + /// producing n output partitions. pub fn with_num_agg_partitions(mut self, n: usize) -> Self { self.num_agg_partitions = n; + if n > 1 { + // Update output partitioning to reflect the repartitioned output. + let group_exprs = self.output_group_expr(); + let mut cache = (*self.cache).clone(); + cache.partitioning = Partitioning::Hash(group_exprs, n); + self.cache = Arc::new(cache); + } self } @@ -933,6 +992,51 @@ impl AggregateExec { )?)) } + /// Pulls batches from a single input partition's aggregate stream and + /// routes each batch to the correct output partition channel. + /// + /// Because the internal hash tables use the same hash function + /// (`REPARTITION_RANDOM_STATE`) and partition count as the output, + /// each emitted batch already belongs entirely to one output partition. + /// We read `last_emitted_partition` from the stream to route without + /// re-hashing. 
+ async fn pull_and_route( + mut stream: GroupedHashAggregateStream, + num_partitions: usize, + senders: Vec>>, + ) { + let mut batches_until_yield = num_partitions; + while let Some(result) = stream.next().await { + match result { + Ok(batch) => { + if batch.num_rows() == 0 { + continue; + } + let partition = stream.last_emitted_partition; + if senders[partition].send(Ok(batch)).is_err() { + // Receiver dropped (e.g. LIMIT), stop + return; + } + + if batches_until_yield == 0 { + tokio::task::yield_now().await; + batches_until_yield = num_partitions; + } else { + batches_until_yield -= 1; + } + } + Err(e) => { + let e = Arc::new(e); + for sender in &senders { + let _ = + sender.send(Err(DataFusionError::from(&e))); + } + return; + } + } + } + } + /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; @@ -1445,6 +1549,14 @@ impl ExecutionPlan for AggregateExec { )?; me.limit_options = self.limit_options; me.dynamic_filter = self.dynamic_filter.clone(); + me.num_agg_partitions = self.num_agg_partitions; + if self.num_agg_partitions > 1 { + let group_exprs = me.output_group_expr(); + let mut cache = (*me.cache).clone(); + cache.partitioning = + Partitioning::Hash(group_exprs, self.num_agg_partitions); + me.cache = Arc::new(cache); + } Ok(Arc::new(me)) } @@ -1454,8 +1566,72 @@ impl ExecutionPlan for AggregateExec { partition: usize, context: Arc, ) -> Result { - self.execute_typed(partition, &context) - .map(|stream| stream.into()) + if self.num_agg_partitions <= 1 { + return self + .execute_typed(partition, &context) + .map(|stream| stream.into()); + } + + // Channel-based multi-output: N input partitions → T output partitions. 
+ let mut state = self.repartition_state.lock(); + if matches!(*state, AggRepartitionState::NotInitialized) { + let num_output = self.num_agg_partitions; + let num_input = self.input.output_partitioning().partition_count(); + + // Create one unbounded channel per output partition. + // Unbounded is safe here: aggregate output is much smaller + // than input, and bounded channels risk deadlock when a + // single input task drains partitions sequentially. + let mut senders = Vec::with_capacity(num_output); + let mut receivers = Vec::with_capacity(num_output); + for _ in 0..num_output { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + senders.push(tx); + receivers.push(Some(rx)); + } + + // Spawn one task per input partition + let mut tasks = Vec::with_capacity(num_input); + for input_idx in 0..num_input { + let stream = GroupedHashAggregateStream::new( + self, &context, input_idx, + )?; + let senders_clone: Vec<_> = + senders.iter().map(|s| s.clone()).collect(); + let n = num_output; + + let task = SpawnedTask::spawn(Self::pull_and_route( + stream, + n, + senders_clone, + )); + tasks.push(task); + } + + *state = AggRepartitionState::Consuming { + receivers, + _abort_helper: tasks, + }; + } + + // Take this partition's receiver + let AggRepartitionState::Consuming { receivers, .. 
} = &mut *state + else { + unreachable!() + }; + + let rx = receivers + .get_mut(partition) + .and_then(|r| r.take()) + .expect("partition receiver already consumed"); + let schema = Arc::clone(&self.schema); + drop(state); + + // Wrap the receiver as a RecordBatchStream + let stream = futures::stream::unfold(rx, |mut rx| async move { + rx.recv().await.map(|batch| (batch, rx)) + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } fn metrics(&self) -> Option { @@ -4156,4 +4332,116 @@ mod tests { Ok(()) } + + /// Test that when `num_agg_partitions > 1`, the aggregate produces + /// multiple output partitions via channels, each containing correctly + /// hash-partitioned data. + #[tokio::test] + async fn test_repartitioned_aggregate_multi_output() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("group_col", DataType::Int32, false), + Field::new("value_col", DataType::Float64, false), + ])); + + // Create 2 input partitions with overlapping group keys + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3])), + Arc::new(Float64Array::from(vec![ + 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, + ])), + ], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(Float64Array::from(vec![100.0, 200.0, 300.0, 400.0])), + ], + )?; + let input_partitions = vec![vec![batch1], vec![batch2]]; + + let task_ctx = Arc::new(TaskContext::default()); + + let group_expr = + vec![(col("group_col", &schema)?, "group_col".to_string())]; + let aggr_expr = vec![Arc::new( + AggregateExprBuilder::new( + sum_udaf(), + vec![col("value_col", &schema)?], + ) + .schema(Arc::clone(&schema)) + .alias("sum_value") + .build()?, + )]; + + let num_output_partitions = 3; + + let exec = TestMemoryExec::try_new( + &input_partitions, + Arc::clone(&schema), + None, + )?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + 
let aggregate_exec = Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(group_expr), + aggr_expr, + vec![None], + exec, + Arc::clone(&schema), + )? + .with_num_agg_partitions(num_output_partitions), + ); + + // Verify output partitioning is Hash with num_output_partitions + assert_eq!( + aggregate_exec.properties().output_partitioning().partition_count(), + num_output_partitions + ); + + // Collect from all output partitions + let mut all_results: std::collections::HashMap = + std::collections::HashMap::new(); + + for partition in 0..num_output_partitions { + let stream = + aggregate_exec.execute(partition, Arc::clone(&task_ctx))?; + let batches: Vec<_> = stream + .collect::>() + .await + .into_iter() + .collect::>>()?; + + for batch in &batches { + let groups = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let sums = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + *all_results.entry(groups.value(i)).or_insert(0.0) += + sums.value(i); + } + } + } + + // Verify: group 1 -> 10+40+100=150, group 2 -> 20+50+200=270, + // group 3 -> 30+60+300=390, group 4 -> 400 + assert_eq!(all_results.len(), 4); + assert_eq!(all_results[&1], 150.0); + assert_eq!(all_results[&2], 270.0); + assert_eq!(all_results[&3], 390.0); + assert_eq!(all_results[&4], 400.0); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 56dbcf5d882b8..811e3b389f685 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -434,6 +434,11 @@ pub(crate) struct GroupedHashAggregateStream { /// Tracks which partition to emit next during final output drain. emit_partition_idx: usize, + /// The internal partition index that produced the current + /// `ProducingOutput` batch. 
Used by channel-based multi-output + /// to route batches to the correct output partition without re-hashing. + pub(super) last_emitted_partition: usize, + // ======================================================================== // TASK-SPECIFIC STATES: // Inner states groups together properties, states for a specific task. @@ -702,6 +707,7 @@ impl GroupedHashAggregateStream { Vec::new() }, emit_partition_idx: num_agg_partitions, + last_emitted_partition: 0, aggregate_arguments, filter_expressions, group_by: agg_group_by, @@ -1216,6 +1222,7 @@ impl GroupedHashAggregateStream { if let Some(batch) = self.emit(largest, EmitTo::All, false)? { + self.last_emitted_partition = largest; self.clear_partition(largest, self.batch_size); Ok(Some(ExecutionState::ProducingOutput(batch))) } else { @@ -1429,6 +1436,7 @@ impl GroupedHashAggregateStream { let idx = self.emit_partition_idx; self.emit_partition_idx += 1; if let Some(batch) = self.emit(idx, EmitTo::All, false)? { + self.last_emitted_partition = idx; return Ok(ExecutionState::ProducingOutput(batch)); } } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 081f10d482e1e..d6fcba365b1b6 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -551,7 +551,7 @@ impl BatchPartitioner { /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions, /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve /// this (so we don't need to clone the entire implementation). 
- fn partition_iter( + pub fn partition_iter( &mut self, batch: RecordBatch, ) -> Result> + Send + '_> { From 8a69b5aa0fa5a6ebdb1f65ebf13a0b9f5f12c625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 6 Mar 2026 22:50:35 +0100 Subject: [PATCH 4/6] Fix channel-based multi-output routing and output partitioning - Fix ProducingOutput to carry partition index alongside batch, ensuring correct routing when emit_next_partition eagerly advances the index - Add use_channels() method to centralize the decision of when to use the channel-based multi-output path - Add update_cache_partitioning() to keep output partitioning in sync when limit_options or num_agg_partitions change - Fix with_new_limit_options to recalculate output partitioning (prevents TopK aggregation from claiming Hash partitioning when channels won't be used) - Guard CombinePartialFinalAggregate from combining when Partial has num_agg_partitions > 1 - Set num_agg_partitions in physical planner when repartitioning is enabled - Keep spawned task references alive via Arc in output streams to prevent abort-on-drop Co-Authored-By: Claude Opus 4.6 --- datafusion/core/src/physical_planner.rs | 38 +++-- .../src/combine_partial_final_agg.rs | 1 + .../physical-plan/src/aggregates/mod.rs | 122 ++++++++------- .../physical-plan/src/aggregates/row_hash.rs | 147 +++++++----------- 4 files changed, 152 insertions(+), 156 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 8d52f555bbf76..524426f597f6b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1043,19 +1043,39 @@ impl DefaultPhysicalPlanner { input_exec }; - let initial_aggr = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - groups.clone(), - aggregates, - filters.clone(), - input_exec, - Arc::clone(&physical_input_schema), - )?); - let can_repartition = !groups.is_empty() && session_state.config().target_partitions() 
> 1 && session_state.config().repartition_aggregations(); + let initial_aggr = if can_repartition { + // When repartitioning is enabled, set num_agg_partitions + // so the partial aggregate produces target_partitions + // output partitions directly (eliminating the need for + // a downstream RepartitionExec). + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates, + filters.clone(), + input_exec, + Arc::clone(&physical_input_schema), + )? + .with_num_agg_partitions( + session_state.config().target_partitions(), + ), + ) + } else { + Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates, + filters.clone(), + input_exec, + Arc::clone(&physical_input_schema), + )?) + }; + // Some aggregators may be modified during initialization for // optimization purposes. For example, a FIRST_VALUE may turn // into a LAST_VALUE with the reverse ordering requirement. diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs index 860406118c1b7..f112c98cbba82 100644 --- a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -73,6 +73,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { }; let transformed = if *input_agg_exec.mode() == AggregateMode::Partial + && input_agg_exec.num_agg_partitions() <= 1 && can_combine( ( agg_exec.group_expr(), diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 32b7477fdda1b..6c8c6f7b1369a 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -634,8 +634,7 @@ enum AggRepartitionState { /// Input tasks have been spawned and are sending to output channels. Consuming { /// One receiver per output partition. `None` if already taken by `execute()`. 
- receivers: - Vec>>>, + receivers: Vec>>>, /// Background tasks; dropped to abort if the exec is dropped. _abort_helper: Vec>, }, @@ -745,7 +744,7 @@ impl AggregateExec { /// Clone this exec, overriding only the limit hint. pub fn with_new_limit_options(&self, limit_options: Option) -> Self { - Self { + let mut new = Self { limit_options, // clone the rest of the fields required_input_ordering: self.required_input_ordering.clone(), @@ -762,7 +761,9 @@ impl AggregateExec { dynamic_filter: self.dynamic_filter.clone(), num_agg_partitions: self.num_agg_partitions, repartition_state: Arc::new(Mutex::new(AggRepartitionState::default())), - } + }; + new.update_cache_partitioning(); + new } pub fn cache(&self) -> &PlanProperties { @@ -910,27 +911,55 @@ impl AggregateExec { &self.mode } + /// Number of internal hash table partitions. + pub fn num_agg_partitions(&self) -> usize { + self.num_agg_partitions + } + /// Set the number of internal hash table partitions for partial aggregation. /// When n > 1, the aggregate also acts as a repartitioning operator, /// producing n output partitions. pub fn with_num_agg_partitions(mut self, n: usize) -> Self { self.num_agg_partitions = n; - if n > 1 { - // Update output partitioning to reflect the repartitioned output. - let group_exprs = self.output_group_expr(); - let mut cache = (*self.cache).clone(); - cache.partitioning = Partitioning::Hash(group_exprs, n); - self.cache = Arc::new(cache); - } + self.update_cache_partitioning(); self } /// Set the limit options for this AggExec pub fn with_limit_options(mut self, limit_options: Option) -> Self { self.limit_options = limit_options; + self.update_cache_partitioning(); self } + /// Returns true if the channel-based multi-output path will be used + /// in execute(). 
+ fn use_channels(&self) -> bool { + self.num_agg_partitions > 1 + && !self.group_by.is_true_no_grouping() + && (self.limit_options.is_none() + || self.is_unordered_unfiltered_group_by_distinct()) + } + + /// Update the cached output partitioning based on whether channels + /// will be used. + fn update_cache_partitioning(&mut self) { + if self.use_channels() { + let group_exprs = self.output_group_expr(); + let mut cache = (*self.cache).clone(); + cache.partitioning = Partitioning::Hash(group_exprs, self.num_agg_partitions); + self.cache = Arc::new(cache); + } else if self.num_agg_partitions > 1 { + // Channels won't be used (e.g. limit/TopK was set), revert + // to single-partition output matching the non-channel path. + let mut cache = (*self.cache).clone(); + cache.partitioning = Partitioning::UnknownPartitioning( + self.input.output_partitioning().partition_count(), + ); + self.cache = Arc::new(cache); + } + } + /// Get the limit options (if set) pub fn limit_options(&self) -> Option { self.limit_options @@ -1028,8 +1057,7 @@ impl AggregateExec { Err(e) => { let e = Arc::new(e); for sender in &senders { - let _ = - sender.send(Err(DataFusionError::from(&e))); + let _ = sender.send(Err(DataFusionError::from(&e))); } return; } @@ -1553,8 +1581,7 @@ impl ExecutionPlan for AggregateExec { if self.num_agg_partitions > 1 { let group_exprs = me.output_group_expr(); let mut cache = (*me.cache).clone(); - cache.partitioning = - Partitioning::Hash(group_exprs, self.num_agg_partitions); + cache.partitioning = Partitioning::Hash(group_exprs, self.num_agg_partitions); me.cache = Arc::new(cache); } @@ -1566,7 +1593,7 @@ impl ExecutionPlan for AggregateExec { partition: usize, context: Arc, ) -> Result { - if self.num_agg_partitions <= 1 { + if !self.use_channels() { return self .execute_typed(partition, &context) .map(|stream| stream.into()); @@ -1593,18 +1620,12 @@ impl ExecutionPlan for AggregateExec { // Spawn one task per input partition let mut tasks = 
Vec::with_capacity(num_input); for input_idx in 0..num_input { - let stream = GroupedHashAggregateStream::new( - self, &context, input_idx, - )?; - let senders_clone: Vec<_> = - senders.iter().map(|s| s.clone()).collect(); + let stream = GroupedHashAggregateStream::new(self, &context, input_idx)?; + let senders_clone: Vec<_> = senders.iter().map(|s| s.clone()).collect(); let n = num_output; - let task = SpawnedTask::spawn(Self::pull_and_route( - stream, - n, - senders_clone, - )); + let task = + SpawnedTask::spawn(Self::pull_and_route(stream, n, senders_clone)); tasks.push(task); } @@ -1615,8 +1636,7 @@ impl ExecutionPlan for AggregateExec { } // Take this partition's receiver - let AggRepartitionState::Consuming { receivers, .. } = &mut *state - else { + let AggRepartitionState::Consuming { receivers, .. } = &mut *state else { unreachable!() }; @@ -1625,12 +1645,17 @@ impl ExecutionPlan for AggregateExec { .and_then(|r| r.take()) .expect("partition receiver already consumed"); let schema = Arc::clone(&self.schema); + // Keep a reference to the shared state so spawned tasks aren't + // aborted when the AggregateExec is dropped but streams are + // still being polled. 
+ let _state_ref = Arc::clone(&self.repartition_state); drop(state); // Wrap the receiver as a RecordBatchStream - let stream = futures::stream::unfold(rx, |mut rx| async move { - rx.recv().await.map(|batch| (batch, rx)) - }); + let stream = + futures::stream::unfold((rx, _state_ref), |(mut rx, state_ref)| async move { + rx.recv().await.map(|batch| (batch, (rx, state_ref))) + }); Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) } @@ -4348,9 +4373,7 @@ mod tests { Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 2, 3, 1, 2, 3])), - Arc::new(Float64Array::from(vec![ - 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, - ])), + Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])), ], )?; let batch2 = RecordBatch::try_new( @@ -4364,25 +4387,17 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); - let group_expr = - vec![(col("group_col", &schema)?, "group_col".to_string())]; + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; let aggr_expr = vec![Arc::new( - AggregateExprBuilder::new( - sum_udaf(), - vec![col("value_col", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("sum_value") - .build()?, + AggregateExprBuilder::new(sum_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("sum_value") + .build()?, )]; let num_output_partitions = 3; - let exec = TestMemoryExec::try_new( - &input_partitions, - Arc::clone(&schema), - None, - )?; + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); let aggregate_exec = Arc::new( @@ -4399,7 +4414,10 @@ mod tests { // Verify output partitioning is Hash with num_output_partitions assert_eq!( - aggregate_exec.properties().output_partitioning().partition_count(), + aggregate_exec + .properties() + .output_partitioning() + .partition_count(), num_output_partitions ); @@ -4408,8 +4426,7 @@ mod tests { std::collections::HashMap::new(); for 
partition in 0..num_output_partitions { - let stream = - aggregate_exec.execute(partition, Arc::clone(&task_ctx))?; + let stream = aggregate_exec.execute(partition, Arc::clone(&task_ctx))?; let batches: Vec<_> = stream .collect::>() .await @@ -4428,8 +4445,7 @@ mod tests { .downcast_ref::() .unwrap(); for i in 0..batch.num_rows() { - *all_results.entry(groups.value(i)).or_insert(0.0) += - sums.value(i); + *all_results.entry(groups.value(i)).or_insert(0.0) += sums.value(i); } } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 811e3b389f685..40ac750803485 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -65,8 +65,10 @@ use log::debug; pub(crate) enum ExecutionState { ReadingInput, /// When producing output, the remaining rows to output are stored - /// here and are sliced off as needed in batch_size chunks - ProducingOutput(RecordBatch), + /// here and are sliced off as needed in batch_size chunks. + /// The `usize` is the internal partition index that produced this batch + /// (used by channel-based multi-output routing). + ProducingOutput(RecordBatch, usize), /// Produce intermediate aggregate state for each input row without /// aggregation. /// @@ -801,7 +803,7 @@ impl Stream for GroupedHashAggregateStream { timer.done(); if let Some(batch) = self.emit(0, to_emit, false)? 
{ self.exec_state = - ExecutionState::ProducingOutput(batch); + ExecutionState::ProducingOutput(batch, 0); }; // make sure the exec_state just set is not overwritten below break 'reading_input; @@ -822,7 +824,6 @@ impl Stream for GroupedHashAggregateStream { self.exec_state = new_state; break 'reading_input; } - } // If we reach this point, try to update the memory reservation @@ -884,13 +885,12 @@ impl Stream for GroupedHashAggregateStream { } } - ExecutionState::ProducingOutput(_) => { + ExecutionState::ProducingOutput(_, _) => { // Take ownership of the batch to allow mutable self access below. - let batch = std::mem::replace( - &mut self.exec_state, - ExecutionState::Done, - ); - let ExecutionState::ProducingOutput(batch) = batch else { + let state = + std::mem::replace(&mut self.exec_state, ExecutionState::Done); + let ExecutionState::ProducingOutput(batch, batch_partition) = state + else { unreachable!() }; // slice off a part of the batch, if needed @@ -921,7 +921,10 @@ impl Stream for GroupedHashAggregateStream { let num_remaining = batch.num_rows() - size; let remaining = batch.slice(size, num_remaining); let output = batch.slice(0, size); - (ExecutionState::ProducingOutput(remaining), output) + ( + ExecutionState::ProducingOutput(remaining, batch_partition), + output, + ) }; if let Some(reduction_factor) = self.reduction_factor.as_ref() { @@ -931,6 +934,7 @@ impl Stream for GroupedHashAggregateStream { // Empty record batches should not be emitted. 
// They need to be treated as [`Option`]es and handled separately debug_assert!(output_batch.num_rows() > 0); + self.last_emitted_partition = batch_partition; return Poll::Ready(Some(Ok( output_batch.record_output(&self.baseline_metrics) ))); @@ -994,8 +998,7 @@ impl GroupedHashAggregateStream { // Evaluate the filter expressions, if any, against the inputs let filter_values = if self.spill_state.is_stream_merging { - let filter_expressions = - vec![None; self.partitions[0].accumulators.len()]; + let filter_expressions = vec![None; self.partitions[0].accumulators.len()]; evaluate_optional(&filter_expressions, batch)? } else { evaluate_optional(&self.filter_expressions, batch)? @@ -1051,12 +1054,9 @@ impl GroupedHashAggregateStream { )?; let num_partitions = self.partitions.len(); - self.partition_indices - .iter_mut() - .for_each(|v| v.clear()); + self.partition_indices.iter_mut().for_each(|v| v.clear()); for (row_idx, hash) in self.hash_buffer.iter().enumerate() { - self.partition_indices - [(*hash % num_partitions as u64) as usize] + self.partition_indices[(*hash % num_partitions as u64) as usize] .push(row_idx as u32); } @@ -1074,8 +1074,7 @@ impl GroupedHashAggregateStream { ); // Take sub-batch for this partition - let part_groups = - take_arrays(group_values, &indices_array, None)?; + let part_groups = take_arrays(group_values, &indices_array, None)?; let part_inputs: Vec> = input_values .iter() .map(|vals| { @@ -1102,10 +1101,9 @@ impl GroupedHashAggregateStream { .collect::>()?; let partition = &mut self.partitions[p]; - partition.group_values.intern( - &part_groups, - &mut partition.current_group_indices, - )?; + partition + .group_values + .intern(&part_groups, &mut partition.current_group_indices)?; Self::accumulate_partition( partition, @@ -1144,23 +1142,13 @@ impl GroupedHashAggregateStream { let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); if is_raw_input { - acc.update_batch( - values, - group_indices, - opt_filter, - 
total_num_groups, - )?; + acc.update_batch(values, group_indices, opt_filter, total_num_groups)?; } else { assert_or_internal_err!( opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage" ); - acc.merge_batch( - values, - group_indices, - None, - total_num_groups, - )?; + acc.merge_batch(values, group_indices, None, total_num_groups)?; } } Ok(()) @@ -1189,9 +1177,7 @@ impl GroupedHashAggregateStream { } OutOfMemoryMode::EmitEarly if self.total_group_len() > 1 => { if self.partitions.len() == 1 { - let n = if self.partitions[0].group_values.len() - >= self.batch_size - { + let n = if self.partitions[0].group_values.len() >= self.batch_size { self.partitions[0].group_values.len() / self.batch_size * self.batch_size } else { @@ -1209,22 +1195,18 @@ impl GroupedHashAggregateStream { }; if n > 0 - && let Some(batch) = - self.emit(0, EmitTo::First(n), false)? + && let Some(batch) = self.emit(0, EmitTo::First(n), false)? { - Ok(Some(ExecutionState::ProducingOutput(batch))) + Ok(Some(ExecutionState::ProducingOutput(batch, 0))) } else { Err(oom) } } else { // Multi-partition: emit the largest partition let largest = self.largest_partition_idx(); - if let Some(batch) = - self.emit(largest, EmitTo::All, false)? - { - self.last_emitted_partition = largest; + if let Some(batch) = self.emit(largest, EmitTo::All, false)? { self.clear_partition(largest, self.batch_size); - Ok(Some(ExecutionState::ProducingOutput(batch))) + Ok(Some(ExecutionState::ProducingOutput(batch, largest))) } else { Err(oom) } @@ -1247,13 +1229,12 @@ impl GroupedHashAggregateStream { acc + groups_size + self.group_ordering.size() + indices_size; // Reserve extra headroom for sorting during potential spill. 
- let sort_headroom = if self.oom_mode == OutOfMemoryMode::Spill - && !self.all_partitions_empty() - { - acc + groups_size - } else { - 0 - }; + let sort_headroom = + if self.oom_mode == OutOfMemoryMode::Spill && !self.all_partitions_empty() { + acc + groups_size + } else { + 0 + }; let new_size = groups_and_acc_size + sort_headroom; let reservation_result = self.reservation.try_resize(new_size); @@ -1419,10 +1400,7 @@ impl GroupedHashAggregateStream { .iter() .enumerate() .max_by_key(|(_, p)| { - p.accumulators - .iter() - .map(|a| a.size()) - .sum::() + p.accumulators.iter().map(|a| a.size()).sum::() + p.group_values.size() }) .map(|(i, _)| i) @@ -1436,8 +1414,7 @@ impl GroupedHashAggregateStream { let idx = self.emit_partition_idx; self.emit_partition_idx += 1; if let Some(batch) = self.emit(idx, EmitTo::All, false)? { - self.last_emitted_partition = idx; - return Ok(ExecutionState::ProducingOutput(batch)); + return Ok(ExecutionState::ProducingOutput(batch, idx)); } } Ok(ExecutionState::Done) @@ -1466,7 +1443,9 @@ impl GroupedHashAggregateStream { if self.partitions.len() == 1 { // Single partition: flush remaining group values. let batch = self.emit(0, EmitTo::All, false)?; - batch.map_or(ExecutionState::Done, ExecutionState::ProducingOutput) + batch.map_or(ExecutionState::Done, |b| { + ExecutionState::ProducingOutput(b, 0) + }) } else { // Multi-partition: emit partitions one at a time. 
self.emit_partition_idx = 0; @@ -1973,24 +1952,16 @@ mod tests { let input_partitions = vec![vec![batch]]; let task_ctx = Arc::new(TaskContext::default()); - let group_expr = - vec![(col("group_col", &schema)?, "group_col".to_string())]; + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; let aggr_expr = vec![Arc::new( - AggregateExprBuilder::new( - count_udaf(), - vec![col("value_col", &schema)?], - ) - .schema(Arc::clone(&schema)) - .alias("count_value") - .build()?, + AggregateExprBuilder::new(count_udaf(), vec![col("value_col", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count_value") + .build()?, )]; // Run with 1 partition (baseline) - let exec = TestMemoryExec::try_new( - &input_partitions, - Arc::clone(&schema), - None, - )?; + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); let aggregate_exec_1 = AggregateExec::try_new( @@ -2002,11 +1973,8 @@ mod tests { Arc::clone(&schema), )?; - let mut stream_1 = GroupedHashAggregateStream::new( - &aggregate_exec_1, - &task_ctx, - 0, - )?; + let mut stream_1 = + GroupedHashAggregateStream::new(&aggregate_exec_1, &task_ctx, 0)?; let mut results_1 = std::collections::HashMap::new(); while let Some(Ok(batch)) = stream_1.next().await { let groups = batch @@ -2020,17 +1988,12 @@ mod tests { .downcast_ref::() .unwrap(); for i in 0..batch.num_rows() { - *results_1.entry(groups.value(i)).or_insert(0i64) += - counts.value(i); + *results_1.entry(groups.value(i)).or_insert(0i64) += counts.value(i); } } // Run with 4 partitions - let exec = TestMemoryExec::try_new( - &input_partitions, - Arc::clone(&schema), - None, - )?; + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); let aggregate_exec_4 = AggregateExec::try_new( @@ -2043,11 +2006,8 @@ mod tests { )? 
.with_num_agg_partitions(4); - let mut stream_4 = GroupedHashAggregateStream::new( - &aggregate_exec_4, - &task_ctx, - 0, - )?; + let mut stream_4 = + GroupedHashAggregateStream::new(&aggregate_exec_4, &task_ctx, 0)?; let mut results_4 = std::collections::HashMap::new(); while let Some(Ok(batch)) = stream_4.next().await { let groups = batch @@ -2061,8 +2021,7 @@ mod tests { .downcast_ref::() .unwrap(); for i in 0..batch.num_rows() { - *results_4.entry(groups.value(i)).or_insert(0i64) += - counts.value(i); + *results_4.entry(groups.value(i)).or_insert(0i64) += counts.value(i); } } From 155db34e5d02b5820397acb750febb91c55bf468 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sat, 7 Mar 2026 21:30:41 +0100 Subject: [PATCH 5/6] Add intern_with_indices and update_batch_with_indices to avoid take in partitioned aggregation Replace take_arrays calls in the multi-partition aggregation path with index-based methods that pass row indices directly to GroupValues and GroupsAccumulator implementations. 
- Add intern_with_indices to GroupValues trait with precomputed hashes (default uses take) - Add update_batch_with_indices/merge_batch_with_indices to GroupsAccumulator (default uses take) - Native intern_with_indices for GroupValuesRows, GroupValuesPrimitive, GroupValuesBoolean - Primitive reuses precomputed hashes; row converter reuses precomputed hashes - Compute aggregation hashes once in row_hash.rs and pass to all partitions - Extract shared take_values_and_filter helper, intern_value helpers to reduce duplication - Merge accumulate_partition and accumulate_partition_with_indices into single method Co-Authored-By: Claude Opus 4.6 --- .../expr-common/src/groups_accumulator.rs | 77 ++++++++++++- .../src/aggregates/group_values/mod.rs | 26 ++++- .../src/aggregates/group_values/row.rs | 49 ++++++++ .../group_values/single_group_by/boolean.rs | 65 ++++++----- .../group_values/single_group_by/primitive.rs | 109 +++++++++++++----- .../physical-plan/src/aggregates/row_hash.rs | 83 ++++++------- 6 files changed, 304 insertions(+), 105 deletions(-) diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 08c9f01f13c40..9c71f8f7c4fc6 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -17,8 +17,9 @@ //! Vectorized [`GroupsAccumulator`] -use arrow::array::{ArrayRef, BooleanArray}; -use datafusion_common::{Result, not_impl_err}; +use arrow::array::{Array, ArrayRef, BooleanArray, UInt32Array}; +use arrow::compute::{take, take_arrays}; +use datafusion_common::{DataFusionError, Result, not_impl_err}; /// Describes how many rows should be emitted during grouping. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -198,6 +199,54 @@ pub trait GroupsAccumulator: Send { total_num_groups: usize, ) -> Result<()>; + /// Like [`Self::update_batch`], but only processes the rows at the given + /// `indices` positions in `values`. 
+ /// + /// `group_indices` has one entry per index in `indices` (in the same order). + /// `opt_filter` if present is full-length (covers all rows in `values`). + /// + /// The default implementation uses `take` to extract sub-arrays, then + /// delegates to [`Self::update_batch`]. + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let (taken_values, taken_filter) = + take_values_and_filter(values, indices, opt_filter)?; + self.update_batch( + &taken_values, + group_indices, + taken_filter.as_ref(), + total_num_groups, + ) + } + + /// Like [`Self::merge_batch`], but only processes the rows at the given + /// `indices` positions in `values`. + /// + /// See [`Self::update_batch_with_indices`] for details on parameters. + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let (taken_values, taken_filter) = + take_values_and_filter(values, indices, opt_filter)?; + self.merge_batch( + &taken_values, + group_indices, + taken_filter.as_ref(), + total_num_groups, + ) + } + /// Converts an input batch directly to the intermediate aggregate state. /// /// This is the equivalent of treating each input row as its own group. It @@ -254,3 +303,27 @@ pub trait GroupsAccumulator: Send { /// compute, not `O(num_groups)` fn size(&self) -> usize; } + +/// Extracts sub-arrays and an optional filter at the given `indices`. 
+fn take_values_and_filter( + values: &[ArrayRef], + indices: &[u32], + opt_filter: Option<&BooleanArray>, +) -> Result<(Vec, Option)> { + let indices_array = UInt32Array::from_iter_values(indices.iter().copied()); + let taken_values = + take_arrays(values, &indices_array, None).map_err(DataFusionError::from)?; + let taken_filter = opt_filter + .map(|f| { + let taken = take(f, &indices_array, None)?; + Ok::<_, DataFusionError>( + taken + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + ) + }) + .transpose()?; + Ok((taken_values, taken_filter)) +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2f3b1a19e7d73..d7a91b89da9be 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -22,7 +22,8 @@ use arrow::array::types::{ Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow::array::{ArrayRef, downcast_primitive}; +use arrow::array::{ArrayRef, UInt32Array, downcast_primitive}; +use arrow::compute::take_arrays; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use datafusion_common::Result; @@ -99,6 +100,29 @@ pub trait GroupValues: Send { /// assigned. fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + /// Like [`Self::intern`], but only processes the rows at the given + /// `indices` positions in `cols`. + /// + /// `hashes` contains precomputed hashes for ALL rows in `cols` (full-length). + /// Implementations that need hashes can index into it using the provided `indices`. + /// + /// When the function returns, `groups` contains one group id per + /// index in `indices` (in the same order). + /// + /// The default implementation uses `take` to extract sub-arrays. 
+ fn intern_with_indices( + &mut self, + cols: &[ArrayRef], + hashes: &[u64], + indices: &[u32], + groups: &mut Vec, + ) -> Result<()> { + let _ = hashes; // unused in default impl + let indices_array = UInt32Array::from_iter_values(indices.iter().copied()); + let sub_cols = take_arrays(cols, &indices_array, None)?; + self.intern(&sub_cols, groups) + } + /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index dd794c957350d..6f73ccf754fe7 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -172,6 +172,55 @@ impl GroupValues for GroupValuesRows { Ok(()) } + fn intern_with_indices( + &mut self, + cols: &[ArrayRef], + hashes: &[u64], + indices: &[u32], + groups: &mut Vec, + ) -> Result<()> { + // Convert all group keys into the row format + let group_rows = &mut self.rows_buffer; + group_rows.clear(); + self.row_converter.append(group_rows, cols)?; + + let mut group_values = match self.group_values.take() { + Some(group_values) => group_values, + None => self.row_converter.empty_rows(0, 0), + }; + + groups.clear(); + + // Use precomputed hashes — no need to call create_hashes + for &idx in indices { + let row = idx as usize; + let target_hash = hashes[row]; + let entry = self.map.find_mut(target_hash, |(exist_hash, group_idx)| { + target_hash == *exist_hash + && group_rows.row(row) == group_values.row(*group_idx) + }); + + let group_idx = match entry { + Some((_hash, group_idx)) => *group_idx, + None => { + let group_idx = group_values.num_rows(); + group_values.push(group_rows.row(row)); + self.map.insert_accounted( + (target_hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + + self.group_values = Some(group_values); + + 
Ok(()) + } + fn size(&self) -> usize { let group_values_size = self.group_values.as_ref().map(|v| v.size()).unwrap_or(0); self.row_converter.size() diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..17a229535e003 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -18,7 +18,7 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{ - ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, + Array, ArrayRef, AsArray as _, BooleanArray, BooleanBufferBuilder, NullBufferBuilder, }; use datafusion_common::Result; use datafusion_expr::EmitTo; @@ -39,6 +39,16 @@ impl GroupValuesBoolean { null_group: None, } } + + #[inline(always)] + fn intern_value(&mut self, value: Option) -> usize { + let next_id = self.len(); + match value { + Some(false) => *self.false_group.get_or_insert(next_id), + Some(true) => *self.true_group.get_or_insert(next_id), + None => *self.null_group.get_or_insert(next_id), + } + } } impl GroupValues for GroupValuesBoolean { @@ -47,37 +57,30 @@ impl GroupValues for GroupValuesBoolean { groups.clear(); for value in array.iter() { - let index = match value { - Some(false) => { - if let Some(index) = self.false_group { - index - } else { - let index = self.len(); - self.false_group = Some(index); - index - } - } - Some(true) => { - if let Some(index) = self.true_group { - index - } else { - let index = self.len(); - self.true_group = Some(index); - index - } - } - None => { - if let Some(index) = self.null_group { - index - } else { - let index = self.len(); - self.null_group = Some(index); - index - } - } - }; + groups.push(self.intern_value(value)); + } + + Ok(()) + } + + fn intern_with_indices( + &mut self, + cols: &[ArrayRef], + _hashes: &[u64], 
+ indices: &[u32], + groups: &mut Vec, + ) -> Result<()> { + let array = cols[0].as_boolean(); + groups.clear(); - groups.push(index); + for &idx in indices { + let idx = idx as usize; + let value = if array.is_null(idx) { + None + } else { + Some(array.value(idx)) + }; + groups.push(self.intern_value(value)); } Ok(()) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 2b8a2cfa68897..e0dbed6056b82 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -19,8 +19,8 @@ use crate::aggregates::group_values::GroupValues; use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::{ - ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, PrimitiveArray, - cast::AsArray, + Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, + PrimitiveArray, cast::AsArray, }; use arrow::datatypes::{DataType, i256}; use datafusion_common::Result; @@ -108,6 +108,60 @@ impl GroupValuesPrimitive { } } +impl GroupValuesPrimitive +where + T::Native: HashValue, +{ + /// Intern a single value and return its group index. + #[inline(always)] + fn intern_value(&mut self, v: Option) -> usize { + match v { + None => self.intern_null(), + Some(key) => { + let hash = key.hash(&self.random_state); + self.intern_key(key, hash) + } + } + } + + /// Intern a single value with a precomputed hash. 
+ #[inline(always)] + fn intern_value_with_hash(&mut self, v: Option, hash: u64) -> usize { + match v { + None => self.intern_null(), + Some(key) => self.intern_key(key, hash), + } + } + + #[inline(always)] + fn intern_null(&mut self) -> usize { + *self.null_group.get_or_insert_with(|| { + let group_id = self.values.len(); + self.values.push(Default::default()); + group_id + }) + } + + #[inline(always)] + fn intern_key(&mut self, key: T::Native, hash: u64) -> usize { + let insert = self.map.entry( + hash, + |&(g, h)| unsafe { hash == h && self.values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, + ); + + match insert { + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, + hashbrown::hash_table::Entry::Vacant(v) => { + let g = self.values.len(); + v.insert((g, hash)); + self.values.push(key); + g + } + } + } +} + impl GroupValues for GroupValuesPrimitive where T::Native: HashValue, @@ -117,35 +171,30 @@ where groups.clear(); for v in cols[0].as_primitive::() { - let group_id = match v { - None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); - group_id - }), - Some(key) => { - let state = &self.random_state; - let hash = key.hash(state); - let insert = self.map.entry( - hash, - |&(g, h)| unsafe { - hash == h && self.values.get_unchecked(g).is_eq(key) - }, - |&(_, h)| h, - ); - - match insert { - hashbrown::hash_table::Entry::Occupied(o) => o.get().0, - hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); - v.insert((g, hash)); - self.values.push(key); - g - } - } - } + groups.push(self.intern_value(v)); + } + Ok(()) + } + + fn intern_with_indices( + &mut self, + cols: &[ArrayRef], + hashes: &[u64], + indices: &[u32], + groups: &mut Vec, + ) -> Result<()> { + assert_eq!(cols.len(), 1); + groups.clear(); + + let arr = cols[0].as_primitive::(); + for &idx in indices { + let idx = idx as usize; + let v = if arr.is_null(idx) { + None + } else { + Some(arr.value(idx)) }; - 
groups.push(group_id) + groups.push(self.intern_value_with_hash(v, hashes[idx])); } Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 40ac750803485..2b9bf47c68cb1 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -50,9 +50,7 @@ use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; -use crate::repartition::REPARTITION_RANDOM_STATE; use crate::sorts::IncrementalSortIterator; -use arrow::compute::take_arrays; use datafusion_common::hash_utils::create_hashes; use datafusion_common::instant::Instant; use datafusion_common::utils::memory::get_record_batch_memory_size; @@ -1036,6 +1034,7 @@ impl GroupedHashAggregateStream { partition, &input_values, &filter_values, + None, is_raw_input, )?; @@ -1049,7 +1048,7 @@ impl GroupedHashAggregateStream { self.hash_buffer.resize(num_rows, 0); create_hashes( group_values, - REPARTITION_RANDOM_STATE.random_state(), + &super::AGGREGATION_HASH_SEED, &mut self.hash_buffer, )?; @@ -1069,46 +1068,21 @@ impl GroupedHashAggregateStream { if self.partition_indices[p].is_empty() { continue; } - let indices_array = UInt32Array::from_iter_values( - self.partition_indices[p].iter().copied(), - ); - - // Take sub-batch for this partition - let part_groups = take_arrays(group_values, &indices_array, None)?; - let part_inputs: Vec> = input_values - .iter() - .map(|vals| { - take_arrays(vals.as_slice(), &indices_array, None) - .map_err(DataFusionError::from) - }) - .collect::>()?; - let part_filters: Vec> = filter_values - .iter() - .map(|opt_f| { - opt_f - .as_ref() - .map(|f| { - take_arrays( - std::slice::from_ref(f), - &indices_array, - None, - ) - .map(|mut v| v.remove(0)) - .map_err(DataFusionError::from) - }) - .transpose() - }) - .collect::>()?; + let 
indices = &self.partition_indices[p]; let partition = &mut self.partitions[p]; - partition - .group_values - .intern(&part_groups, &mut partition.current_group_indices)?; + partition.group_values.intern_with_indices( + group_values, + &self.hash_buffer, + indices, + &mut partition.current_group_indices, + )?; Self::accumulate_partition( partition, - &part_inputs, - &part_filters, + &input_values, + &filter_values, + Some(indices), is_raw_input, )?; } @@ -1122,11 +1096,13 @@ impl GroupedHashAggregateStream { Ok(()) } - /// Update accumulators for a single partition. + /// Update accumulators for a single partition. When `indices` is `Some`, + /// only the rows at those positions in the input arrays are processed. fn accumulate_partition( partition: &mut PartitionAggState, input_values: &[Vec], filter_values: &[Option], + indices: Option<&[u32]>, is_raw_input: bool, ) -> Result<()> { let group_indices = &partition.current_group_indices; @@ -1142,13 +1118,38 @@ impl GroupedHashAggregateStream { let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); if is_raw_input { - acc.update_batch(values, group_indices, opt_filter, total_num_groups)?; + if let Some(indices) = indices { + acc.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } } else { assert_or_internal_err!( opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage" ); - acc.merge_batch(values, group_indices, None, total_num_groups)?; + if let Some(indices) = indices { + acc.merge_batch_with_indices( + values, + indices, + group_indices, + None, + total_num_groups, + )?; + } else { + acc.merge_batch(values, group_indices, None, total_num_groups)?; + } } } Ok(()) From d2733aafe0c622e4aa83923f4e2ff82cc44e93cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Sun, 8 Mar 
2026 05:55:13 +0100 Subject: [PATCH 6/6] Implement more --- .../groups_accumulator/accumulate.rs | 210 ++++++++++++++++++ .../aggregate/groups_accumulator/bool_op.rs | 50 +++++ .../aggregate/groups_accumulator/prim_op.rs | 46 ++++ .../functions-aggregate/src/array_agg.rs | 86 +++++++ .../functions-aggregate/src/correlation.rs | 82 ++++++- datafusion/functions-aggregate/src/count.rs | 53 +++++ .../src/min_max/min_max_bytes.rs | 142 ++++++++++++ .../src/min_max/min_max_struct.rs | 84 +++++++ datafusion/functions-aggregate/src/stddev.rs | 34 +++ .../functions-aggregate/src/variance.rs | 74 +++++- 10 files changed, 859 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 25f52df61136f..07eb7484a06ea 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -288,6 +288,114 @@ impl NullState { } } + /// Like [`Self::accumulate`], but only processes the rows at the given + /// `indices` positions in `values`. + /// + /// `group_indices[i]` corresponds to `values[indices[i]]`. 
+ pub fn accumulate_with_indices( + &mut self, + group_indices: &[usize], + values: &PrimitiveArray, + indices: &[u32], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, + { + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + accumulate_with_indices(group_indices, values, indices, None, value_fn); + *num_values = total_num_groups; + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); + accumulate_with_indices( + group_indices, + values, + indices, + opt_filter, + |group_index, value| { + seen_values.set_bit(group_index, true); + value_fn(group_index, value); + }, + ); + } + + /// Like [`Self::accumulate_boolean`], but only processes the rows at the + /// given `indices` positions in `values`. + /// + /// `group_indices[i]` corresponds to `values[indices[i]]`. + pub fn accumulate_boolean_with_indices( + &mut self, + group_indices: &[usize], + values: &BooleanArray, + indices: &[u32], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, bool) + Send, + { + let data = values.values(); + assert_eq!(group_indices.len(), indices.len()); + + if let SeenValues::All { num_values } = &mut self.seen_values + && opt_filter.is_none() + && values.null_count() == 0 + { + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + value_fn(group_index, data.value(idx as usize)); + } + *num_values = total_num_groups; + return; + } + + let seen_values = self.seen_values.get_builder(total_num_groups); + + match (values.null_count() > 0, opt_filter) { + (false, None) => { + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + seen_values.set_bit(group_index, true); + value_fn(group_index, data.value(idx as usize)); + } + } + (true, None) => { + let nulls = values.nulls().unwrap(); + 
for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if nulls.is_valid(idx) { + seen_values.set_bit(group_index, true); + value_fn(group_index, data.value(idx)); + } + } + } + (false, Some(filter)) => { + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if !filter.is_null(idx) && filter.value(idx) { + seen_values.set_bit(group_index, true); + value_fn(group_index, data.value(idx)); + } + } + } + (true, Some(filter)) => { + let nulls = values.nulls().unwrap(); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if nulls.is_valid(idx) && !filter.is_null(idx) && filter.value(idx) { + seen_values.set_bit(group_index, true); + value_fn(group_index, data.value(idx)); + } + } + } + } + } + /// Creates the a [`NullBuffer`] representing which group_indices /// should have null values (because they never saw any values) /// for the `emit_to` rows. @@ -673,6 +781,108 @@ pub fn accumulate_indices( } } +/// Like [`accumulate`], but only processes the rows at the given `indices` +/// positions in `values`. +/// +/// `group_indices[i]` corresponds to `values[indices[i]]`. +/// `opt_filter`, if present, is full-length and is checked at `indices[i]`. 
+pub fn accumulate_with_indices( + group_indices: &[usize], + values: &PrimitiveArray, + indices: &[u32], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, +{ + let data: &[T::Native] = values.values(); + assert_eq!(group_indices.len(), indices.len()); + + match (values.null_count() > 0, opt_filter) { + (false, None) => { + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + value_fn(group_index, data[idx as usize]); + } + } + (true, None) => { + let nulls = values.nulls().unwrap(); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if nulls.is_valid(idx) { + value_fn(group_index, data[idx]); + } + } + } + (false, Some(filter)) => { + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if !filter.is_null(idx) && filter.value(idx) { + value_fn(group_index, data[idx]); + } + } + } + (true, Some(filter)) => { + let nulls = values.nulls().unwrap(); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if nulls.is_valid(idx) && !filter.is_null(idx) && filter.value(idx) { + value_fn(group_index, data[idx]); + } + } + } + } +} + +/// Like [`accumulate_multiple`], but only processes the rows at the given +/// `indices` positions in `value_columns`. +/// +/// `group_indices[i]` corresponds to row `indices[i]` in all value columns. 
+pub fn accumulate_multiple_with_indices( + group_indices: &[usize], + value_columns: &[&PrimitiveArray], + indices: &[u32], + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, +{ + assert_eq!(group_indices.len(), indices.len()); + + let combined_nulls = value_columns + .iter() + .map(|arr| arr.logical_nulls()) + .fold(None, |acc, nulls| { + NullBuffer::union(acc.as_ref(), nulls.as_ref()) + }); + + let valid_indices = match (combined_nulls, opt_filter) { + (None, None) => None, + (None, Some(filter)) => Some(filter.clone()), + (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)), + (Some(nulls), Some(filter)) => { + let combined = nulls.inner() & filter.values(); + Some(BooleanArray::new(combined, None)) + } + }; + + match valid_indices { + None => { + for (&group_idx, &idx) in group_indices.iter().zip(indices.iter()) { + value_fn(group_idx, idx as usize, value_columns); + } + } + Some(valid_indices) => { + for (&group_idx, &idx) in group_indices.iter().zip(indices.iter()) { + if valid_indices.value(idx as usize) { + value_fn(group_idx, idx as usize, value_columns); + } + } + } + } +} + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index f716b48f0cccc..52be2d2788ffd 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -104,6 +104,38 @@ where Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_boolean(); 
+ + if self.values.len() < total_num_groups { + let new_groups = total_num_groups - self.values.len(); + self.values.append_n(new_groups, self.identity); + } + + self.null_state.accumulate_boolean_with_indices( + group_indices, + values, + indices, + opt_filter, + total_num_groups, + |group_index, new_value| { + let current_value = self.values.get_bit(group_index); + let value = (self.bool_fn)(current_value, new_value); + self.values.set_bit(group_index, value); + }, + ); + + Ok(()) + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { let values = self.values.finish(); @@ -139,6 +171,24 @@ where self.update_batch(values, group_indices, opt_filter, total_num_groups) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + self.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + ) + } + fn size(&self) -> usize { // capacity is in bits, so convert to bytes self.values.capacity() / 8 + self.null_state.size() diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index acf875b686139..7c515de1f098e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -115,6 +115,34 @@ where Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + + self.values.resize(total_num_groups, self.starting_value); + + self.null_state.accumulate_with_indices( + 
group_indices, + values, + indices, + opt_filter, + total_num_groups, + |group_index, new_value| { + let value = unsafe { self.values.get_unchecked_mut(group_index) }; + (self.prim_fn)(value, new_value); + }, + ); + + Ok(()) + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { let values = emit_to.take_needed(&mut self.values); let nulls = self.null_state.build(emit_to); @@ -138,6 +166,24 @@ where self.update_batch(values, group_indices, opt_filter, total_num_groups) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // update / merge are the same + self.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + ) + } + /// Converts an input batch directly to a state batch /// /// The state is: diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index cd4cb9b19ff77..3398a03311003 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -592,6 +592,54 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator { Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let input = &values[0]; + + self.num_groups = self.num_groups.max(total_num_groups); + + let nulls = if self.ignore_nulls { + input.logical_nulls() + } else { + None + }; + + let mut entries = Vec::new(); + + for (&group_idx, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + // Skip filtered rows + if let Some(filter) = opt_filter + && (filter.is_null(idx) || !filter.value(idx)) + { + continue; + } + + // Skip null values when ignore_nulls is set + if 
let Some(ref nulls) = nulls + && nulls.is_null(idx) + { + continue; + } + + entries.push((group_idx as u32, idx as u32)); + } + + if !entries.is_empty() { + self.batches.push(Arc::clone(input)); + self.batch_entries.push(entries); + } + + Ok(()) + } + /// Produce a `ListArray` ordered by group index: the list at /// position N contains the aggregated values for group N. /// @@ -711,6 +759,44 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator { Ok(()) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + let input_list = values[0].as_list::(); + + self.num_groups = self.num_groups.max(total_num_groups); + + let list_values = input_list.values(); + let list_offsets = input_list.offsets(); + + let mut entries = Vec::new(); + + for (&group_idx, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if input_list.is_null(idx) { + continue; + } + let start = list_offsets[idx] as u32; + let end = list_offsets[idx + 1] as u32; + for pos in start..end { + entries.push((group_idx as u32, pos)); + } + } + + if !entries.is_empty() { + self.batches.push(Arc::clone(list_values)); + self.batch_entries.push(entries); + } + + Ok(()) + } + fn convert_to_state( &self, values: &[ArrayRef], diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs index 6c76c6e940099..57f975a6092e3 100644 --- a/datafusion/functions-aggregate/src/correlation.rs +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -33,7 +33,9 @@ use arrow::{ datatypes::{DataType, Field}, }; use datafusion_expr::{EmitTo, GroupsAccumulator}; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_multiple; +use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::{ + accumulate_multiple, accumulate_multiple_with_indices, +}; use log::debug; use crate::covariance::CovarianceAccumulator; @@ -410,6 +412,44 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + self.sum_yy.resize(total_num_groups, 0.0); + + let array_x = downcast_array::(&values[0]); + let array_y = downcast_array::(&values[1]); + + accumulate_multiple_with_indices( + group_indices, + &[&array_x, &array_y], + indices, + opt_filter, + |group_index, batch_index, columns| { + let x = columns[0].value(batch_index); + let y = columns[1].value(batch_index); + self.count[group_index] += 1; + self.sum_x[group_index] += x; + self.sum_y[group_index] += y; + self.sum_xy[group_index] += x * y; + self.sum_xx[group_index] += x * x; + self.sum_yy[group_index] += y * y; + }, + ); + + Ok(()) + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { // Drain the state vectors for the groups being emitted let counts = emit_to.take_needed(&mut self.count); @@ -542,6 +582,46 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator { Ok(()) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.count.resize(total_num_groups, 0); + self.sum_x.resize(total_num_groups, 0.0); + self.sum_y.resize(total_num_groups, 0.0); + self.sum_xy.resize(total_num_groups, 0.0); + self.sum_xx.resize(total_num_groups, 0.0); + 
self.sum_yy.resize(total_num_groups, 0.0); + + let partial_counts = values[0].as_primitive::(); + let partial_sum_x = values[1].as_primitive::(); + let partial_sum_y = values[2].as_primitive::(); + let partial_sum_xy = values[3].as_primitive::(); + let partial_sum_xx = values[4].as_primitive::(); + let partial_sum_yy = values[5].as_primitive::(); + + assert!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + self.count[group_index] += partial_counts.value(idx); + self.sum_x[group_index] += partial_sum_x.value(idx); + self.sum_y[group_index] += partial_sum_y.value(idx); + self.sum_xy[group_index] += partial_sum_xy.value(idx); + self.sum_xx[group_index] += partial_sum_xx.value(idx); + self.sum_yy[group_index] += partial_sum_yy.value(idx); + } + + Ok(()) + } + fn size(&self) -> usize { self.count.capacity() * size_of::() + self.sum_x.capacity() * size_of::() diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 376cf39745903..5b529bec93d14 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -598,6 +598,38 @@ impl GroupsAccumulator for CountGroupsAccumulator { Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &values[0]; + + self.counts.resize(total_num_groups, 0); + let nulls = values.logical_nulls(); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if let Some(ref nulls) = nulls { + if !nulls.is_valid(idx) { + continue; + } + } + if let Some(filter) = opt_filter { + if filter.is_null(idx) || 
!filter.value(idx) { + continue; + } + } + let count = unsafe { self.counts.get_unchecked_mut(group_index) }; + *count += 1; + } + + Ok(()) + } + fn merge_batch( &mut self, values: &[ArrayRef], @@ -625,6 +657,27 @@ impl GroupsAccumulator for CountGroupsAccumulator { Ok(()) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + let partial_counts = values[0].as_primitive::(); + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + self.counts.resize(total_num_groups, 0); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + self.counts[group_index] += partial_counts[idx as usize]; + } + + Ok(()) + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { let counts = emit_to.take_needed(&mut self.counts); diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs index e4ac7eccf5692..023d7327a68a0 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -199,6 +199,130 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { } } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.data_type(), &self.inner.data_type); + + // Helper: check if row at `idx` should be included (not null, not filtered) + let is_valid_row = |idx: usize| -> bool { + if array.is_null(idx) { + return false; + } + if let Some(f) = opt_filter { + if f.is_null(idx) || !f.value(idx) { + return false; + } + } + true + }; + + fn string_min(a: &[u8], b: &[u8]) -> bool { + unsafe { + let a = 
std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a < b + } + } + fn string_max(a: &[u8], b: &[u8]) -> bool { + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a > b + } + } + fn binary_min(a: &[u8], b: &[u8]) -> bool { + a < b + } + fn binary_max(a: &[u8], b: &[u8]) -> bool { + a > b + } + + // Create indexed iterator that yields Option<&[u8]> for each + // (group_index, index) pair + macro_rules! indexed_str_update { + ($array_accessor:expr, $cmp:expr) => {{ + let typed_array = $array_accessor; + let iter = indices.iter().map(|&idx| { + let idx = idx as usize; + if is_valid_row(idx) { + Some(typed_array.value(idx).as_bytes()) + } else { + None + } + }); + self.inner + .update_batch(iter, group_indices, total_num_groups, $cmp) + }}; + } + + macro_rules! indexed_bin_update { + ($array_accessor:expr, $cmp:expr) => {{ + let typed_array = $array_accessor; + let iter = indices.iter().map(|&idx| { + let idx = idx as usize; + if is_valid_row(idx) { + Some(typed_array.value(idx)) + } else { + None + } + }); + self.inner + .update_batch(iter, group_indices, total_num_groups, $cmp) + }}; + } + + match (self.is_min, &self.inner.data_type) { + (true, &DataType::Utf8) => { + indexed_str_update!(array.as_string::(), string_min) + } + (true, &DataType::LargeUtf8) => { + indexed_str_update!(array.as_string::(), string_min) + } + (true, &DataType::Utf8View) => { + indexed_str_update!(array.as_string_view(), string_min) + } + (false, &DataType::Utf8) => { + indexed_str_update!(array.as_string::(), string_max) + } + (false, &DataType::LargeUtf8) => { + indexed_str_update!(array.as_string::(), string_max) + } + (false, &DataType::Utf8View) => { + indexed_str_update!(array.as_string_view(), string_max) + } + (true, &DataType::Binary) => { + indexed_bin_update!(array.as_binary::(), binary_min) + } + (true, &DataType::LargeBinary) => { + indexed_bin_update!(array.as_binary::(), binary_min) + } + (true, 
&DataType::BinaryView) => { + indexed_bin_update!(array.as_binary_view(), binary_min) + } + (false, &DataType::Binary) => { + indexed_bin_update!(array.as_binary::(), binary_max) + } + (false, &DataType::LargeBinary) => { + indexed_bin_update!(array.as_binary::(), binary_max) + } + (false, &DataType::BinaryView) => { + indexed_bin_update!(array.as_binary_view(), binary_max) + } + _ => internal_err!( + "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})", + self.is_min, + self.inner.data_type + ), + } + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); @@ -314,6 +438,24 @@ impl GroupsAccumulator for MinMaxBytesAccumulator { self.update_batch(values, group_indices, opt_filter, total_num_groups) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + ) + } + fn convert_to_state( &self, values: &[ArrayRef], diff --git a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs index 796fd586ca5c8..1c5ddeea7af6d 100644 --- a/datafusion/functions-aggregate/src/min_max/min_max_struct.rs +++ b/datafusion/functions-aggregate/src/min_max/min_max_struct.rs @@ -98,6 +98,72 @@ impl GroupsAccumulator for MinMaxStructAccumulator { } } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.data_type(), &self.inner.data_type); + + fn struct_min(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Less)) + } + 
+ fn struct_max(a: &StructArray, b: &StructArray) -> bool { + matches!(partial_cmp_struct(a, b), Some(Ordering::Greater)) + } + + let cmp_fn: fn(&StructArray, &StructArray) -> bool = + if self.is_min { struct_min } else { struct_max }; + + let struct_array = array.as_struct(); + self.inner.min_max.resize(total_num_groups, None); + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + if array.is_null(idx) { + continue; + } + if let Some(filter) = opt_filter { + if filter.is_null(idx) || !filter.value(idx) { + continue; + } + } + let new_val = struct_array.slice(idx, 1); + + let existing_val = match &locations[group_index] { + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(existing_val) = self.inner.min_max[group_index].as_ref() + else { + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + existing_val + } + }; + + if cmp_fn(&new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => { + self.inner.set_value(group_index, new_val) + } + } + } + Ok(()) + } + fn evaluate(&mut self, emit_to: EmitTo) -> Result { let (_, min_maxes) = self.inner.emit_to(emit_to); let fields = match &self.inner.data_type { @@ -139,6 +205,24 @@ impl GroupsAccumulator for MinMaxStructAccumulator { self.update_batch(values, group_indices, opt_filter, total_num_groups) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, 
+ total_num_groups, + ) + } + fn convert_to_state( &self, values: &[ArrayRef], diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 6f77e7df92547..0364bf02b7eac 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -336,6 +336,23 @@ impl GroupsAccumulator for StddevGroupsAccumulator { .update_batch(values, group_indices, opt_filter, total_num_groups) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.variance.update_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + ) + } + fn merge_batch( &mut self, values: &[ArrayRef], @@ -347,6 +364,23 @@ impl GroupsAccumulator for StddevGroupsAccumulator { .merge_batch(values, group_indices, opt_filter, total_num_groups) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.variance.merge_batch_with_indices( + values, + indices, + group_indices, + opt_filter, + total_num_groups, + ) + } + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { let (mut variances, nulls) = self.variance.variance(emit_to); variances.iter_mut().for_each(|v| *v = v.sqrt()); diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index fb089ba4f9cea..453f39693b8cd 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -34,7 +34,8 @@ use datafusion_expr::{ }; use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_functions_aggregate_common::{ - aggregate::groups_accumulator::accumulate::accumulate, 
stats::StatsType, + aggregate::groups_accumulator::accumulate::{accumulate, accumulate_with_indices}, + stats::StatsType, }; use datafusion_macros::user_doc; use std::mem::{size_of, size_of_val}; @@ -525,6 +526,38 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { Ok(()) } + fn update_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = as_float64_array(&values[0])?; + + self.resize(total_num_groups); + accumulate_with_indices( + group_indices, + values, + indices, + opt_filter, + |group_index, value| { + let (new_count, new_mean, new_m2) = update( + self.counts[group_index], + self.means[group_index], + self.m2s[group_index], + value, + ); + self.counts[group_index] = new_count; + self.means[group_index] = new_mean; + self.m2s[group_index] = new_m2; + }, + ); + Ok(()) + } + fn merge_batch( &mut self, values: &[ArrayRef], @@ -566,6 +599,45 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { Ok(()) } + fn merge_batch_with_indices( + &mut self, + values: &[ArrayRef], + indices: &[u32], + group_indices: &[usize], + _opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 3, "three arguments to merge_batch"); + let partial_counts = as_uint64_array(&values[0])?; + let partial_means = as_float64_array(&values[1])?; + let partial_m2s = as_float64_array(&values[2])?; + + assert_eq!(partial_counts.null_count(), 0); + assert_eq!(partial_means.null_count(), 0); + assert_eq!(partial_m2s.null_count(), 0); + + self.resize(total_num_groups); + for (&group_index, &idx) in group_indices.iter().zip(indices.iter()) { + let idx = idx as usize; + let partial_count = partial_counts.value(idx); + if partial_count == 0 { + continue; + } + let (new_count, new_mean, new_m2) = merge( + self.counts[group_index], + 
self.means[group_index], + self.m2s[group_index], + partial_count, + partial_means.value(idx), + partial_m2s.value(idx), + ); + self.counts[group_index] = new_count; + self.means[group_index] = new_mean; + self.m2s[group_index] = new_m2; + } + Ok(()) + } + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { let (variances, nulls) = self.variance(emit_to); Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls))))