diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs index ea1a87d091481..250e3f91c3a6d 100644 --- a/datafusion/physical-plan/src/coalesce/mod.rs +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::RecordBatch; +use arrow::array::{PrimitiveArray, RecordBatch}; use arrow::compute::BatchCoalescer; -use arrow::datatypes::SchemaRef; -use datafusion_common::{Result, assert_or_internal_err}; +use arrow::datatypes::{SchemaRef, UInt32Type}; +use datafusion_common::{DataFusionError, Result, assert_or_internal_err}; /// Concatenate multiple [`RecordBatch`]es and apply a limit /// @@ -138,6 +138,45 @@ impl LimitedBatchCoalescer { self.finished } + /// Push a batch with indices into the coalescer, selecting only the rows + /// indicated by `indices`. Returns the push status. + pub fn push_batch_with_indices( + &mut self, + batch: RecordBatch, + indices: &PrimitiveArray, + ) -> Result { + assert_or_internal_err!( + !self.finished, + "LimitedBatchCoalescer: cannot push batch after finish" + ); + + let num_rows = indices.len(); + + if let Some(fetch) = self.fetch { + if self.total_rows >= fetch { + return Ok(PushBatchStatus::LimitReached); + } + + if self.total_rows + num_rows >= fetch { + let remaining_rows = fetch - self.total_rows; + debug_assert!(remaining_rows > 0); + let indices = indices.slice(0, remaining_rows); + self.total_rows += remaining_rows; + self.inner + .push_batch_with_indices(batch, &indices) + .map_err(DataFusionError::from)?; + return Ok(PushBatchStatus::LimitReached); + } + } + + self.total_rows += num_rows; + self.inner + .push_batch_with_indices(batch, indices) + .map_err(DataFusionError::from)?; + + Ok(PushBatchStatus::Continue) + } + /// Return the next completed batch, if any pub fn next_completed_batch(&mut self) -> Option { self.inner.next_completed_batch() diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 081f10d482e1e..7482000609e75 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -71,7 +71,7 @@ use crate::sort_pushdown::SortOrderPushdownResult; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::stream::Stream; -use futures::{FutureExt, StreamExt, TryStreamExt, ready}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use log::trace; use parking_lot::Mutex; @@ -272,6 +272,7 @@ impl RepartitionExecState { name: &str, context: &Arc, spill_manager: SpillManager, + fetch: Option, ) -> Result<&mut ConsumingInputStreamsState> { let streams_and_metrics = match self { RepartitionExecState::NotInitialized => { @@ -388,6 +389,13 @@ impl RepartitionExecState { .map(|(partition, channel)| (*partition, channel.sender.clone())) .collect(); + // Skip sender-side coalescing when preserve_order is set, + // because the subsequent merge sort already does batching. + let target_batch_size = if preserve_order { + None + } else { + Some(context.session_config().batch_size()) + }; let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( stream, txs, @@ -396,6 +404,8 @@ impl RepartitionExecState { // preserve_order depends on partition index to start from 0 if preserve_order { 0 } else { i }, num_input_partitions, + target_batch_size, + fetch, )); // In a separate task, wait for each input to be done @@ -416,9 +426,20 @@ impl RepartitionExecState { } /// A utility that can be used to partition batches based on [`Partitioning`] +/// +/// When constructed with coalescers (via [`Self::with_coalescers`]), it accumulates +/// rows per output partition using [`LimitedBatchCoalescer`] and only yields batches +/// when `target_batch_size` is reached. For hash partitioning this uses +/// `push_batch_with_indices` to avoid materializing intermediate sub-batches. pub struct BatchPartitioner { state: BatchPartitionerState, timer: metrics::Time, + /// Optional per-partition coalescers for accumulating rows. + /// When present, `partition_iter` pushes into coalescers and yields completed batches. + /// When absent, `partition_iter` yields sub-batches directly. + coalescers: Option>, + /// Buffer for completed (partition, batch) pairs returned by `partition_iter`. + completed: Vec<(usize, RecordBatch)>, } enum BatchPartitionerState { @@ -441,14 +462,6 @@ pub const REPARTITION_RANDOM_STATE: SeededRandomState = impl BatchPartitioner { /// Create a new [`BatchPartitioner`] for hash-based repartitioning. - /// - /// # Parameters - /// - `exprs`: Expressions used to compute the hash for each input row. - /// - `num_partitions`: Total number of output partitions. - /// - `timer`: Metric used to record time spent during repartitioning. - /// - /// # Notes - /// This constructor cannot fail and performs no validation. pub fn new_hash_partitioner( exprs: Vec>, num_partitions: usize, @@ -462,18 +475,13 @@ impl BatchPartitioner { indices: vec![vec![]; num_partitions], }, timer, + coalescers: None, + completed: Vec::new(), } } /// Create a new [`BatchPartitioner`] for round-robin repartitioning. /// - /// # Parameters - /// - `num_partitions`: Total number of output partitions. - /// - `timer`: Metric used to record time spent during repartitioning. - /// - `input_partition`: Index of the current input partition. - /// - `num_input_partitions`: Total number of input partitions. - /// - /// # Notes /// The starting output partition is derived from the input partition /// to avoid skew when multiple input partitions are used. pub fn new_round_robin_partitioner( @@ -488,21 +496,12 @@ impl BatchPartitioner { next_idx: (input_partition * num_partitions) / num_input_partitions, }, timer, + coalescers: None, + completed: Vec::new(), } } + /// Create a new [`BatchPartitioner`] based on the provided [`Partitioning`] scheme. - /// - /// This is a convenience constructor that delegates to the specialized - /// hash or round-robin constructors depending on the partitioning variant. - /// - /// # Parameters - /// - `partitioning`: Partitioning scheme to apply (hash or round-robin). - /// - `timer`: Metric used to record time spent during repartitioning. - /// - `input_partition`: Index of the current input partition. - /// - `num_input_partitions`: Total number of input partitions. - /// - /// # Errors - /// Returns an error if the provided partitioning scheme is not supported. pub fn try_new( partitioning: Partitioning, timer: metrics::Time, @@ -527,6 +526,40 @@ impl BatchPartitioner { } } + /// Enable coalescing with per-partition [`LimitedBatchCoalescer`]s. + /// + /// When enabled, `partition_iter` accumulates rows into coalescers and + /// only yields batches when `target_batch_size` is reached (or on `finish`). + /// For hash partitioning, this uses `push_batch_with_indices` to avoid + /// materializing intermediate sub-batches. + pub fn with_coalescers( + mut self, + schema: SchemaRef, + target_batch_size: usize, + fetch: Option, + ) -> Self { + let num_partitions = self.num_partitions(); + self.coalescers = Some( + (0..num_partitions) + .map(|_| { + LimitedBatchCoalescer::new( + Arc::clone(&schema), + target_batch_size, + fetch, + ) + }) + .collect(), + ); + self + } + + fn num_partitions(&self) -> usize { + match &self.state { + BatchPartitionerState::RoundRobin { num_partitions, .. } => *num_partitions, + BatchPartitionerState::Hash { num_partitions, .. } => *num_partitions, + } + } + /// Partition the provided [`RecordBatch`] into one or more partitioned [`RecordBatch`] /// based on the [`Partitioning`] specified on construction /// @@ -546,6 +579,31 @@ impl BatchPartitioner { }) } + /// Drain any completed batches from coalescers for the given partitions. + fn drain_completed(&mut self, partitions: impl Iterator) { + if let Some(coalescers) = &mut self.coalescers { + for partition in partitions { + while let Some(batch) = coalescers[partition].next_completed_batch() { + self.completed.push((partition, batch)); + } + } + } + } + + /// Flush all coalescers and yield remaining buffered batches. + /// Call this after all input batches have been processed. + pub fn finish( + &mut self, + ) -> Result + '_> { + if let Some(coalescers) = &mut self.coalescers { + for coalescer in coalescers.iter_mut() { + coalescer.finish()?; + } + } + self.drain_completed(0..self.num_partitions()); + Ok(self.completed.drain(..)) + } + /// Actual implementation of [`partition`](Self::partition). /// /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions, @@ -555,65 +613,101 @@ impl BatchPartitioner { &mut self, batch: RecordBatch, ) -> Result> + Send + '_> { - let it: Box> + Send> = - match &mut self.state { - BatchPartitionerState::RoundRobin { - num_partitions, - next_idx, - } => { - let idx = *next_idx; - *next_idx = (*next_idx + 1) % *num_partitions; - Box::new(std::iter::once(Ok((idx, batch)))) + self.completed.clear(); + match &mut self.state { + BatchPartitionerState::RoundRobin { + num_partitions, + next_idx, + } => { + let idx = *next_idx; + *next_idx = (*next_idx + 1) % *num_partitions; + if let Some(coalescers) = &mut self.coalescers { + coalescers[idx].push_batch(batch)?; + while let Some(completed) = + coalescers[idx].next_completed_batch() + { + self.completed.push((idx, completed)); + } + } else { + self.completed.push((idx, batch)); } - BatchPartitionerState::Hash { - exprs, - num_partitions: partitions, + } + BatchPartitionerState::Hash { + exprs, + num_partitions: partitions, + hash_buffer, + indices, + } => { + let timer = self.timer.timer(); + + let arrays = + evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?; + + hash_buffer.clear(); + hash_buffer.resize(batch.num_rows(), 0); + + create_hashes( + &arrays, + REPARTITION_RANDOM_STATE.random_state(), hash_buffer, - indices, - } => { - // Tracking time required for distributing indexes across output partitions - let timer = self.timer.timer(); + )?; - let arrays = - evaluate_expressions_to_arrays(exprs.as_slice(), &batch)?; + indices.iter_mut().for_each(|v| v.clear()); - hash_buffer.clear(); - hash_buffer.resize(batch.num_rows(), 0); + for (index, hash) in hash_buffer.iter().enumerate() { + indices[(*hash % *partitions as u64) as usize] + .push(index as u32); + } - create_hashes( - &arrays, - REPARTITION_RANDOM_STATE.random_state(), - hash_buffer, - )?; + timer.done(); - indices.iter_mut().for_each(|v| v.clear()); + if let Some(coalescers) = &mut self.coalescers { + // Push indices directly into coalescers, avoiding + // materializing intermediate sub-batches. + for (partition, p_indices) in indices.iter_mut().enumerate() { + if !p_indices.is_empty() { + let taken_indices = std::mem::take(p_indices); + let indices_array: PrimitiveArray = + taken_indices.into(); - for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize].push(index as u32); - } + coalescers[partition].push_batch_with_indices( + batch.clone(), + &indices_array, + )?; - // Finished building index-arrays for output partitions - timer.done(); + // Return the taken vec for reuse + let (_, buffer, _) = indices_array.into_parts(); + if let Ok(mut vec) = + buffer.into_inner().into_vec::() + { + vec.clear(); + *p_indices = vec; + } - // Borrowing partitioner timer to prevent moving `self` to closure + while let Some(completed) = + coalescers[partition].next_completed_batch() + { + self.completed.push((partition, completed)); + } + } + } + } else { + // No coalescers: materialize sub-batches directly let partitioner_timer = &self.timer; - - let mut partitioned_batches = vec![]; for (partition, p_indices) in indices.iter_mut().enumerate() { if !p_indices.is_empty() { let taken_indices = std::mem::take(p_indices); let indices_array: PrimitiveArray = taken_indices.into(); - // Tracking time required for repartitioned batches construction let _timer = partitioner_timer.timer(); - // Produce batches based on indices let columns = take_arrays(batch.columns(), &indices_array, None)?; let mut options = RecordBatchOptions::new(); - options = options.with_row_count(Some(indices_array.len())); + options = + options.with_row_count(Some(indices_array.len())); let batch = RecordBatch::try_new_with_options( batch.schema(), columns, @@ -621,12 +715,14 @@ impl BatchPartitioner { ) .unwrap(); - partitioned_batches.push(Ok((partition, batch))); + self.completed.push((partition, batch)); // Return the taken vec let (_, buffer, _) = indices_array.into_parts(); - let mut vec = - buffer.into_inner().into_vec::().map_err(|e| { + let mut vec = buffer + .into_inner() + .into_vec::() + .map_err(|e| { internal_datafusion_err!( "Could not convert buffer to vec: {e:?}" ) @@ -635,20 +731,11 @@ impl BatchPartitioner { *p_indices = vec; } } - - Box::new(partitioned_batches.into_iter()) } - }; - - Ok(it) - } - - // return the number of output partitions - fn num_partitions(&self) -> usize { - match self.state { - BatchPartitionerState::RoundRobin { num_partitions, .. } => num_partitions, - BatchPartitionerState::Hash { num_partitions, .. } => num_partitions, + } } + + Ok(self.completed.drain(..).map(Ok)) } } @@ -766,6 +853,8 @@ pub struct RepartitionExec { /// Boolean flag to decide whether to preserve ordering. If true means /// `SortPreservingRepartitionExec`, false means `RepartitionExec`. preserve_order: bool, + /// Optional fetch limit (maximum number of rows per output partition) + fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. cache: Arc, } @@ -876,6 +965,9 @@ impl DisplayAs for RepartitionExec { if let Some(sort_exprs) = self.sort_exprs() { write!(f, ", sort_exprs={}", sort_exprs.clone())?; } + if let Some(fetch) = self.fetch { + write!(f, ", fetch={fetch}")?; + } Ok(()) } DisplayFormatType::TreeRender => { @@ -942,6 +1034,7 @@ impl ExecutionPlan for RepartitionExec { if self.preserve_order { repartition = repartition.with_preserve_order(); } + repartition.fetch = self.fetch; Ok(Arc::new(repartition)) } @@ -994,6 +1087,7 @@ impl ExecutionPlan for RepartitionExec { } let num_input_partitions = input.output_partitioning().partition_count(); + let fetch = self.fetch; let stream = futures::stream::once(async move { // lock scope @@ -1008,6 +1102,7 @@ impl ExecutionPlan for RepartitionExec { &name, &context, spill_manager.clone(), + fetch, )?; // now return stream for the specified *output* partition which will @@ -1050,7 +1145,6 @@ impl ExecutionPlan for RepartitionExec { spill_stream, 1, // Each receiver handles one input partition BaselineMetrics::new(&metrics, partition), - None, // subsequent merge sort already does batching https://github.com/apache/datafusion/blob/e4dcf0c85611ad0bd291f03a8e03fe56d773eb16/datafusion/physical-plan/src/sorts/merge.rs#L286 )) as SendableRecordBatchStream }) .collect::>(); @@ -1089,7 +1183,6 @@ impl ExecutionPlan for RepartitionExec { spill_stream, num_input_partitions, BaselineMetrics::new(&metrics, partition), - Some(context.session_config().batch_size()), )) as SendableRecordBatchStream) } }) @@ -1144,7 +1237,26 @@ impl ExecutionPlan for RepartitionExec { } fn cardinality_effect(&self) -> CardinalityEffect { - CardinalityEffect::Equal + if self.fetch.is_some() { + CardinalityEffect::LowerEqual + } else { + CardinalityEffect::Equal + } + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(RepartitionExec { + input: Arc::clone(&self.input), + state: Arc::clone(&self.state), + metrics: self.metrics.clone(), + preserve_order: self.preserve_order, + fetch: limit, + cache: Arc::clone(&self.cache), + })) } fn try_swapping_with_projection( @@ -1243,6 +1355,7 @@ impl ExecutionPlan for RepartitionExec { state: Arc::clone(&self.state), metrics: self.metrics.clone(), preserve_order: self.preserve_order, + fetch: self.fetch, cache: new_properties.into(), }))) } @@ -1263,6 +1376,7 @@ impl RepartitionExec { state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, + fetch: None, cache: Arc::new(cache), }) } @@ -1337,6 +1451,50 @@ impl RepartitionExec { } } + /// Send a completed batch through the output channel, handling memory + /// reservation and spilling. + /// + /// Returns `true` if the channel is still open, `false` if the receiver + /// has hung up (e.g. LIMIT). + async fn send_batch( + batch: RecordBatch, + partition: usize, + output_channels: &mut HashMap, + metrics: &RepartitionMetrics, + ) -> bool { + let Some(channel) = output_channels.get_mut(&partition) else { + return false; + }; + + let size = batch.get_array_memory_size(); + let timer = metrics.send_time[partition].timer(); + + let can_grow = channel.reservation.lock().try_grow(size).is_ok(); + let (batch_to_send, is_memory_batch) = if can_grow { + (RepartitionBatch::Memory(batch), true) + } else { + if let Err(e) = channel.spill_writer.push_batch(&batch) { + let _ = channel.sender.send(Some(Err(e.into()))).await; + timer.done(); + output_channels.remove(&partition); + return false; + } + (RepartitionBatch::Spilled, false) + }; + + if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() { + if is_memory_batch { + channel.reservation.lock().shrink(size); + } + timer.done(); + output_channels.remove(&partition); + return false; + } + + timer.done(); + true + } + /// Pulls data from the specified input plan, feeding it to the /// output partitions based on the desired partitioning /// @@ -1348,6 +1506,8 @@ impl RepartitionExec { metrics: RepartitionMetrics, input_partition: usize, num_input_partitions: usize, + target_batch_size: Option, + fetch: Option, ) -> Result<()> { let mut partitioner = match &partitioning { Partitioning::Hash(exprs, num_partitions) => { @@ -1370,81 +1530,41 @@ impl RepartitionExec { } }; + let schema = stream.schema(); + // Enable coalescers when target_batch_size is set and schema has columns. + // Skip coalescing for 0-column schemas (BatchCoalescer can't handle them) + // and when preserve_order is set (the subsequent merge sort already does batching). + if let Some(batch_size) = + target_batch_size.filter(|_| !schema.fields().is_empty()) + { + partitioner = partitioner.with_coalescers(schema, batch_size, fetch); + } + // While there are still outputs to send to, keep pulling inputs - let mut batches_until_yield = partitioner.num_partitions(); while !output_channels.is_empty() { - // fetch the next batch let timer = metrics.fetch_time.timer(); let result = stream.next().await; timer.done(); - // Input is done let batch = match result { Some(result) => result?, None => break, }; - // Handle empty batch if batch.num_rows() == 0 { continue; } for res in partitioner.partition_iter(batch)? { let (partition, batch) = res?; - let size = batch.get_array_memory_size(); - - let timer = metrics.send_time[partition].timer(); - // if there is still a receiver, send to it - if let Some(channel) = output_channels.get_mut(&partition) { - let (batch_to_send, is_memory_batch) = - match channel.reservation.lock().try_grow(size) { - Ok(_) => { - // Memory available - send in-memory batch - (RepartitionBatch::Memory(batch), true) - } - Err(_) => { - // We're memory limited - spill to SpillPool - // SpillPool handles file handle reuse and rotation - channel.spill_writer.push_batch(&batch)?; - // Send marker indicating batch was spilled - (RepartitionBatch::Spilled, false) - } - }; - - if channel.sender.send(Some(Ok(batch_to_send))).await.is_err() { - // If the other end has hung up, it was an early shutdown (e.g. LIMIT) - // Only shrink memory if it was a memory batch - if is_memory_batch { - channel.reservation.lock().shrink(size); - } - output_channels.remove(&partition); - } - } - timer.done(); + Self::send_batch(batch, partition, &mut output_channels, &metrics) + .await; } + } - // If the input stream is endless, we may spin forever and - // never yield back to tokio. See - // https://github.com/apache/datafusion/issues/5278. - // - // However, yielding on every batch causes a bottleneck - // when running with multiple cores. See - // https://github.com/apache/datafusion/issues/6290 - // - // Thus, heuristically yield after producing num_partition - // batches - // - // In round robin this is ideal as each input will get a - // new batch. In hash partitioning it may yield too often - // on uneven distributions even if some partition can not - // make progress, but parallelism is going to be limited - // in that case anyways - if batches_until_yield == 0 { - tokio::task::yield_now().await; - batches_until_yield = partitioner.num_partitions(); - } else { - batches_until_yield -= 1; - } + // Flush remaining buffered batches from coalescers + for (partition, batch) in partitioner.finish()? { + Self::send_batch(batch, partition, &mut output_channels, &metrics).await; } // Spill writers will auto-finalize when dropped @@ -1578,13 +1698,9 @@ struct PerPartitionStream { /// Execution metrics baseline_metrics: BaselineMetrics, - - /// None for sort preserving variant (merge sort already does coalescing) - batch_coalescer: Option, } impl PerPartitionStream { - #[expect(clippy::too_many_arguments)] fn new( schema: SchemaRef, receiver: DistributionReceiver, @@ -1593,10 +1709,7 @@ impl PerPartitionStream { spill_stream: SendableRecordBatchStream, num_input_partitions: usize, baseline_metrics: BaselineMetrics, - batch_size: Option, ) -> Self { - let batch_coalescer = - batch_size.map(|s| LimitedBatchCoalescer::new(Arc::clone(&schema), s, None)); Self { schema, receiver, @@ -1606,7 +1719,6 @@ impl PerPartitionStream { state: StreamState::ReadingMemory, remaining_partitions: num_input_partitions, baseline_metrics, - batch_coalescer, } } @@ -1690,43 +1802,6 @@ impl PerPartitionStream { } } } - - fn poll_next_and_coalesce( - self: &mut Pin<&mut Self>, - cx: &mut Context<'_>, - coalescer: &mut LimitedBatchCoalescer, - ) -> Poll>> { - let cloned_time = self.baseline_metrics.elapsed_compute().clone(); - let mut completed = false; - - loop { - if let Some(batch) = coalescer.next_completed_batch() { - return Poll::Ready(Some(Ok(batch))); - } - if completed { - return Poll::Ready(None); - } - - match ready!(self.poll_next_inner(cx)) { - Some(Ok(batch)) => { - let _timer = cloned_time.timer(); - if let Err(err) = coalescer.push_batch(batch) { - return Poll::Ready(Some(Err(err))); - } - } - Some(err) => { - return Poll::Ready(Some(err)); - } - None => { - completed = true; - let _timer = cloned_time.timer(); - if let Err(err) = coalescer.finish() { - return Poll::Ready(Some(Err(err))); - } - } - } - } - } } impl Stream for PerPartitionStream { @@ -1736,13 +1811,7 @@ impl Stream for PerPartitionStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll; - if let Some(mut coalescer) = self.batch_coalescer.take() { - poll = self.poll_next_and_coalesce(cx, &mut coalescer); - self.batch_coalescer = Some(coalescer); - } else { - poll = self.poll_next_inner(cx); - } + let poll = self.poll_next_inner(cx); self.baseline_metrics.record_poll(poll) } } @@ -1793,13 +1862,14 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(4)).await?; assert_eq!(4, output_partitions.len()); - for partition in &output_partitions { - assert_eq!(1, partition.len()); - } - assert_eq!(13 * 8, output_partitions[0][0].num_rows()); - assert_eq!(13 * 8, output_partitions[1][0].num_rows()); - assert_eq!(12 * 8, output_partitions[2][0].num_rows()); - assert_eq!(12 * 8, output_partitions[3][0].num_rows()); + let rows: Vec = output_partitions + .iter() + .map(|p| p.iter().map(|b| b.num_rows()).sum()) + .collect(); + assert_eq!(13 * 8, rows[0]); + assert_eq!(13 * 8, rows[1]); + assert_eq!(12 * 8, rows[2]); + assert_eq!(12 * 8, rows[3]); Ok(()) } @@ -1816,7 +1886,8 @@ mod tests { repartition(&schema, partitions, Partitioning::RoundRobinBatch(1)).await?; assert_eq!(1, output_partitions.len()); - assert_eq!(150 * 8, output_partitions[0][0].num_rows()); + let total_rows: usize = output_partitions[0].iter().map(|b| b.num_rows()).sum(); + assert_eq!(150 * 8, total_rows); Ok(()) } @@ -1834,9 +1905,9 @@ mod tests { let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - for partition in output_partitions { - assert_eq!(1, partition.len()); - assert_eq!(total_rows_per_partition, partition[0].num_rows()); + for partition in &output_partitions { + let total_rows: usize = partition.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows_per_partition, total_rows); } Ok(()) @@ -1940,9 +2011,9 @@ mod tests { let total_rows_per_partition = 8 * 50 * 3 / 5; assert_eq!(5, output_partitions.len()); - for partition in output_partitions { - assert_eq!(1, partition.len()); - assert_eq!(total_rows_per_partition, partition[0].num_rows()); + for partition in &output_partitions { + let total_rows: usize = partition.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows_per_partition, total_rows); } Ok(()) @@ -2355,9 +2426,8 @@ mod tests { let partitioning = Partitioning::RoundRobinBatch(4); // Set up context with moderate memory limit to force partial spilling - // 2KB should allow some batches in memory but force others to spill let runtime = RuntimeEnvBuilder::default() - .with_memory_limit(2 * 1024, 1.0) + .with_memory_limit(128 * 1024, 1.0) .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime);