From d0c9bbc3def4418067a135b2b43ee62ce5240e0d Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 6 Mar 2026 19:56:11 -0600 Subject: [PATCH 1/6] feat: Vectorize merging `Statistic`s --- datafusion/common/src/stats.rs | 764 ++++++++++++++++++----- datafusion/common/src/utils/aggregate.rs | 380 +++++++++++ datafusion/common/src/utils/mod.rs | 1 + 3 files changed, 996 insertions(+), 149 deletions(-) create mode 100644 datafusion/common/src/utils/aggregate.rs diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 759ebfe67a812..bc799b778593b 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -22,6 +22,9 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; use crate::error::_plan_err; +use crate::utils::aggregate::{ + is_primitive_scalar, vectorized_max, vectorized_min, vectorized_sum, +}; use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to @@ -576,23 +579,6 @@ impl Statistics { /// If not, maybe you can call `SchemaMapper::map_column_statistics` to make them consistent. /// /// Returns an error if the statistics do not match the specified schemas. - pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result - where - I: IntoIterator, - { - let mut items = items.into_iter(); - - let Some(init) = items.next() else { - return Ok(Statistics::new_unknown(schema)); - }; - items.try_fold(init.clone(), |acc: Statistics, item_stats: &Statistics| { - acc.try_merge(item_stats) - }) - } - - /// Merge this Statistics value with another Statistics value. - /// - /// Returns an error if the statistics do not match (different schemas). /// /// # Example /// ``` @@ -600,68 +586,194 @@ impl Statistics { /// # use arrow::datatypes::{Field, Schema, DataType}; /// # use datafusion_common::stats::Precision; /// let stats1 = Statistics::default() - /// .with_num_rows(Precision::Exact(1)) - /// .with_total_byte_size(Precision::Exact(2)) + /// .with_num_rows(Precision::Exact(10)) /// .add_column_statistics( /// ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Exact(3)) - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(5))), + /// .with_min_value(Precision::Exact(ScalarValue::from(1))) + /// .with_max_value(Precision::Exact(ScalarValue::from(100))) + /// .with_sum_value(Precision::Exact(ScalarValue::from(500))), /// ); /// /// let stats2 = Statistics::default() - /// .with_num_rows(Precision::Exact(10)) - /// .with_total_byte_size(Precision::Inexact(20)) + /// .with_num_rows(Precision::Exact(20)) /// .add_column_statistics( /// ColumnStatistics::new_unknown() - /// // absent null count - /// .with_min_value(Precision::Exact(ScalarValue::from(40))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))), + /// .with_min_value(Precision::Exact(ScalarValue::from(5))) + /// .with_max_value(Precision::Exact(ScalarValue::from(200))) + /// .with_sum_value(Precision::Exact(ScalarValue::from(1000))), /// ); /// - /// let merged_stats = stats1.try_merge(&stats2).unwrap(); - /// let expected_stats = Statistics::default() - /// .with_num_rows(Precision::Exact(11)) - /// .with_total_byte_size(Precision::Inexact(22)) // inexact in stats2 --> inexact - /// .add_column_statistics( - /// ColumnStatistics::new_unknown() - /// .with_null_count(Precision::Absent) // missing from stats2 --> absent - /// .with_min_value(Precision::Exact(ScalarValue::from(4))) - /// .with_max_value(Precision::Exact(ScalarValue::from(50))), - /// ); + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let merged = Statistics::try_merge_iter( + /// &[stats1, stats2], + /// &schema, + /// ).unwrap(); /// - /// assert_eq!(merged_stats, expected_stats) + /// assert_eq!(merged.num_rows, Precision::Exact(30)); + /// assert_eq!(merged.column_statistics[0].min_value, + /// Precision::Exact(ScalarValue::from(1))); + /// assert_eq!(merged.column_statistics[0].max_value, + /// Precision::Exact(ScalarValue::from(200))); + /// assert_eq!(merged.column_statistics[0].sum_value, + /// Precision::Exact(ScalarValue::from(1500))); /// ``` - pub fn try_merge(self, other: &Statistics) -> Result { - let Self { - mut num_rows, - mut total_byte_size, - mut column_statistics, - } = self; - - // Accumulate statistics for subsequent items - num_rows = num_rows.add(&other.num_rows); - total_byte_size = total_byte_size.add(&other.total_byte_size); - - if column_statistics.len() != other.column_statistics.len() { - return _plan_err!( - "Cannot merge statistics with different number of columns: {} vs {}", - column_statistics.len(), - other.column_statistics.len() - ); + pub fn try_merge_iter<'a, I>(items: I, schema: &Schema) -> Result + where + I: IntoIterator, + { + let items: Vec<&Statistics> = items.into_iter().collect(); + + if items.is_empty() { + return Ok(Statistics::new_unknown(schema)); + } + if items.len() == 1 { + return Ok(items[0].clone()); + } + + let num_cols = items[0].column_statistics.len(); + // Validate all items have the same number of columns + for (i, stat) in items.iter().enumerate().skip(1) { + if stat.column_statistics.len() != num_cols { + return _plan_err!( + "Cannot merge statistics with different number of columns: {} vs {} (item {})", + num_cols, + stat.column_statistics.len(), + i + ); + } + } + + // Aggregate usize fields (cheap arithmetic) + let mut num_rows = Precision::Exact(0usize); + let mut total_byte_size = Precision::Exact(0usize); + for stat in &items { + num_rows = num_rows.add(&stat.num_rows); + total_byte_size = total_byte_size.add(&stat.total_byte_size); } - for (item_col_stats, col_stats) in other + // We look at the first available value type for each column to decide whether to + // use vectorized aggregation + let col_is_primitive: Vec = + (0..num_cols) + .map(|col_idx| { + items + .first() + .and_then(|s| { + s.column_statistics[col_idx].sum_value.get_value().or_else( + || s.column_statistics[col_idx].min_value.get_value(), + ) + }) + .map(is_primitive_scalar) + .unwrap_or(false) + }) + .collect(); + + let first = items[0]; + let mut column_statistics: Vec = first .column_statistics .iter() - .zip(column_statistics.iter_mut()) - { - col_stats.null_count = col_stats.null_count.add(&item_col_stats.null_count); - col_stats.max_value = col_stats.max_value.max(&item_col_stats.max_value); - col_stats.min_value = col_stats.min_value.min(&item_col_stats.min_value); - col_stats.sum_value = col_stats.sum_value.add(&item_col_stats.sum_value); - col_stats.distinct_count = Precision::Absent; - col_stats.byte_size = col_stats.byte_size.add(&item_col_stats.byte_size); + .map(|cs| ColumnStatistics { + null_count: cs.null_count, + max_value: cs.max_value.clone(), + min_value: cs.min_value.clone(), + sum_value: cs.sum_value.clone(), + distinct_count: Precision::Absent, + byte_size: cs.byte_size, + }) + .collect(); + + // fold non-primitive columns and accumulate usize + // stats for all columns. + for stat in items.iter().skip(1) { + for (col_idx, col_stats) in column_statistics.iter_mut().enumerate() { + let item_cs = &stat.column_statistics[col_idx]; + + col_stats.null_count = col_stats.null_count.add(&item_cs.null_count); + col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size); + + // Non-primitive columns: fold sum/min/max directly using + // Precision::add/min/max + if !col_is_primitive[col_idx] { + col_stats.sum_value = col_stats.sum_value.add(&item_cs.sum_value); + col_stats.min_value = col_stats.min_value.min(&item_cs.min_value); + col_stats.max_value = col_stats.max_value.max(&item_cs.max_value); + } + } + } + + // Column-major vectorized pass: only for primitive columns. + // Collects values into Vecs and uses Arrow kernels for sum/min/max. + for col_idx in 0..num_cols { + if !col_is_primitive[col_idx] { + continue; + } + + // These booleans check whether the values were either all exact or if there + // were any inexact + let mut sum_values: Vec = Vec::with_capacity(items.len()); + let mut min_values: Vec = Vec::with_capacity(items.len()); + let mut max_values: Vec = Vec::with_capacity(items.len()); + let mut sum_all_exact = true; + let mut min_all_exact = true; + let mut max_all_exact = true; + let mut sum_any_absent = false; + let mut min_any_absent = false; + let mut max_any_absent = false; + + for stat in &items { + let cs = &stat.column_statistics[col_idx]; + + match &cs.sum_value { + Precision::Exact(v) => sum_values.push(v.clone()), + Precision::Inexact(v) => { + sum_all_exact = false; + sum_values.push(v.clone()); + } + Precision::Absent => { + sum_any_absent = true; + } + } + match &cs.min_value { + Precision::Exact(v) => min_values.push(v.clone()), + Precision::Inexact(v) => { + min_all_exact = false; + min_values.push(v.clone()); + } + Precision::Absent => { + min_any_absent = true; + } + } + match &cs.max_value { + Precision::Exact(v) => max_values.push(v.clone()), + Precision::Inexact(v) => { + max_all_exact = false; + max_values.push(v.clone()); + } + Precision::Absent => { + max_any_absent = true; + } + } + } + + let col_stats = &mut column_statistics[col_idx]; + + col_stats.sum_value = if sum_any_absent || sum_values.is_empty() { + Precision::Absent + } else { + vectorized_sum(&sum_values, sum_all_exact)? + }; + + col_stats.min_value = if min_any_absent || min_values.is_empty() { + Precision::Absent + } else { + vectorized_min(&min_values, min_all_exact)? + }; + + col_stats.max_value = if max_any_absent || max_values.is_empty() { + Precision::Absent + } else { + vectorized_max(&max_values, max_all_exact)? + }; } Ok(Statistics { @@ -1141,7 +1253,7 @@ mod tests { } #[test] - fn test_try_merge_basic() { + fn test_try_merge() { // Create a schema with two columns let schema = Arc::new(Schema::new(vec![ Field::new("col1", DataType::Int32, false), @@ -1338,52 +1450,6 @@ mod tests { ); } - #[test] - fn test_try_merge_distinct_count_absent() { - // Create statistics with known distinct counts - let stats1 = Statistics::default() - .with_num_rows(Precision::Exact(10)) - .with_total_byte_size(Precision::Exact(100)) - .add_column_statistics( - ColumnStatistics::new_unknown() - .with_null_count(Precision::Exact(0)) - .with_min_value(Precision::Exact(ScalarValue::Int32(Some(1)))) - .with_max_value(Precision::Exact(ScalarValue::Int32(Some(10)))) - .with_distinct_count(Precision::Exact(5)), - ); - - let stats2 = Statistics::default() - .with_num_rows(Precision::Exact(15)) - .with_total_byte_size(Precision::Exact(150)) - .add_column_statistics( - ColumnStatistics::new_unknown() - .with_null_count(Precision::Exact(0)) - .with_min_value(Precision::Exact(ScalarValue::Int32(Some(5)))) - .with_max_value(Precision::Exact(ScalarValue::Int32(Some(20)))) - .with_distinct_count(Precision::Exact(7)), - ); - - // Merge statistics - let merged_stats = stats1.try_merge(&stats2).unwrap(); - - // Verify the results - assert_eq!(merged_stats.num_rows, Precision::Exact(25)); - assert_eq!(merged_stats.total_byte_size, Precision::Exact(250)); - - let col_stats = &merged_stats.column_statistics[0]; - assert_eq!(col_stats.null_count, Precision::Exact(0)); - assert_eq!( - col_stats.min_value, - Precision::Exact(ScalarValue::Int32(Some(1))) - ); - assert_eq!( - col_stats.max_value, - Precision::Exact(ScalarValue::Int32(Some(20))) - ); - // Distinct count should be Absent after merge - assert_eq!(col_stats.distinct_count, Precision::Absent); - } - #[test] fn test_with_fetch_basic_preservation() { // Test that column statistics and byte size are preserved (as inexact) when applying fetch @@ -1650,44 +1716,6 @@ mod tests { assert_eq!(result_col_stats.distinct_count, Precision::Inexact(789)); } - #[test] - fn test_byte_size_try_merge() { - // Test that byte_size is summed correctly in try_merge - let col_stats1 = ColumnStatistics { - null_count: Precision::Exact(10), - max_value: Precision::Absent, - min_value: Precision::Absent, - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - byte_size: Precision::Exact(1000), - }; - let col_stats2 = ColumnStatistics { - null_count: Precision::Exact(20), - max_value: Precision::Absent, - min_value: Precision::Absent, - sum_value: Precision::Absent, - distinct_count: Precision::Absent, - byte_size: Precision::Exact(2000), - }; - - let stats1 = Statistics { - num_rows: Precision::Exact(50), - total_byte_size: Precision::Exact(1000), - column_statistics: vec![col_stats1], - }; - let stats2 = Statistics { - num_rows: Precision::Exact(100), - total_byte_size: Precision::Exact(2000), - column_statistics: vec![col_stats2], - }; - - let merged = stats1.try_merge(&stats2).unwrap(); - assert_eq!( - merged.column_statistics[0].byte_size, - Precision::Exact(3000) // 1000 + 2000 - ); - } - #[test] fn test_byte_size_to_inexact() { let col_stats = ColumnStatistics { @@ -1785,4 +1813,442 @@ mod tests { // total_byte_size should fall back to scaling: 8000 * 0.1 = 800 assert_eq!(result.total_byte_size, Precision::Inexact(800)); } + + #[test] + fn test_try_merge_iter_basic() { + let schema = Arc::new(Schema::new(vec![ + Field::new("col1", DataType::Int32, false), + Field::new("col2", DataType::Int32, false), + ])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(200))), + min_value: Precision::Exact(ScalarValue::Int32(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }, + ], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ + ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(600))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), + }, + ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int32(Some(180))), + min_value: Precision::Exact(ScalarValue::Int32(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(1200))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(60), + }, + ], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(250)); + + let col1_stats = &summary_stats.column_statistics[0]; + assert_eq!(col1_stats.null_count, Precision::Exact(3)); + assert_eq!( + col1_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col1_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(-10))) + ); + assert_eq!( + col1_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(1100))) + ); + + let col2_stats = &summary_stats.column_statistics[1]; + assert_eq!(col2_stats.null_count, Precision::Exact(5)); + assert_eq!( + col2_stats.max_value, + Precision::Exact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col2_stats.min_value, + Precision::Exact(ScalarValue::Int32(Some(5))) + ); + assert_eq!( + col2_stats.sum_value, + Precision::Exact(ScalarValue::Int32(Some(2200))) + ); + } + + #[test] + fn test_try_merge_iter_mixed_precision() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(15), + total_byte_size: Precision::Exact(150), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(120))), + min_value: Precision::Exact(ScalarValue::Int32(Some(-10))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(25)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(250)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(120))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-10))) + ); + // sum_value becomes Absent because stats2 has Absent sum + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_empty() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let items: Vec<&Statistics> = vec![]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Absent); + assert_eq!(summary_stats.total_byte_size, Precision::Absent); + assert_eq!(summary_stats.column_statistics.len(), 1); + assert_eq!( + summary_stats.column_statistics[0].null_count, + Precision::Absent + ); + } + + #[test] + fn test_try_merge_iter_single_item() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int32(Some(100))), + min_value: Precision::Exact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Exact(10), + byte_size: Precision::Exact(40), + }], + }; + + let items = vec![&stats]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats, stats); + } + + #[test] + fn test_try_merge_iter_mismatched_columns() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics::default(); + let stats2 = + Statistics::default().add_column_statistics(ColumnStatistics::new_unknown()); + + let items = vec![&stats1, &stats2]; + let e = Statistics::try_merge_iter(items, &schema).unwrap_err(); + assert_contains!( + e.to_string(), + "Cannot merge statistics with different number of columns: 0 vs 1" + ); + } + + #[test] + fn test_try_merge_iter_three_items() { + // Verify that merging three items works correctly + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int64, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(1), + max_value: Precision::Exact(ScalarValue::Int64(Some(100))), + min_value: Precision::Exact(ScalarValue::Int64(Some(10))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(500))), + distinct_count: Precision::Exact(8), + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(20), + total_byte_size: Precision::Exact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(2), + max_value: Precision::Exact(ScalarValue::Int64(Some(200))), + min_value: Precision::Exact(ScalarValue::Int64(Some(5))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(1000))), + distinct_count: Precision::Exact(15), + byte_size: Precision::Exact(160), + }], + }; + + let stats3 = Statistics { + num_rows: Precision::Exact(30), + total_byte_size: Precision::Exact(300), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(3), + max_value: Precision::Exact(ScalarValue::Int64(Some(150))), + min_value: Precision::Exact(ScalarValue::Int64(Some(1))), + sum_value: Precision::Exact(ScalarValue::Int64(Some(2000))), + distinct_count: Precision::Exact(25), + byte_size: Precision::Exact(240), + }], + }; + + let items = vec![&stats1, &stats2, &stats3]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Exact(60)); + assert_eq!(summary_stats.total_byte_size, Precision::Exact(600)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Exact(6)); + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Int64(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Int64(Some(1))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Int64(Some(3500))) + ); + assert_eq!(col_stats.byte_size, Precision::Exact(480)); + // distinct_count is always Absent after merge (can't accurately merge NDV) + assert_eq!(col_stats.distinct_count, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_float_types() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Float64, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(99.9))), + min_value: Precision::Exact(ScalarValue::Float64(Some(1.1))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(500.5))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(80), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Float64(Some(200.0))), + min_value: Precision::Exact(ScalarValue::Float64(Some(0.5))), + sum_value: Precision::Exact(ScalarValue::Float64(Some(1000.0))), + distinct_count: Precision::Absent, + byte_size: Precision::Exact(80), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Float64(Some(200.0))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Float64(Some(0.5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Exact(ScalarValue::Float64(Some(1500.5))) + ); + } + + #[test] + fn test_try_merge_iter_string_types() { + let schema = + Arc::new(Schema::new(vec![Field::new("col1", DataType::Utf8, false)])); + + let stats1 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("dog".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Exact(10), + total_byte_size: Precision::Exact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(0), + max_value: Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))), + min_value: Precision::Exact(ScalarValue::Utf8(Some("bat".to_string()))), + sum_value: Precision::Absent, + distinct_count: Precision::Absent, + byte_size: Precision::Exact(100), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!( + col_stats.max_value, + Precision::Exact(ScalarValue::Utf8(Some("zebra".to_string()))) + ); + assert_eq!( + col_stats.min_value, + Precision::Exact(ScalarValue::Utf8(Some("ant".to_string()))) + ); + assert_eq!(col_stats.sum_value, Precision::Absent); + } + + #[test] + fn test_try_merge_iter_all_inexact() { + let schema = Arc::new(Schema::new(vec![Field::new( + "col1", + DataType::Int32, + false, + )])); + + let stats1 = Statistics { + num_rows: Precision::Inexact(10), + total_byte_size: Precision::Inexact(100), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(1), + max_value: Precision::Inexact(ScalarValue::Int32(Some(100))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(1))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(500))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(40), + }], + }; + + let stats2 = Statistics { + num_rows: Precision::Inexact(20), + total_byte_size: Precision::Inexact(200), + column_statistics: vec![ColumnStatistics { + null_count: Precision::Inexact(2), + max_value: Precision::Inexact(ScalarValue::Int32(Some(200))), + min_value: Precision::Inexact(ScalarValue::Int32(Some(-5))), + sum_value: Precision::Inexact(ScalarValue::Int32(Some(1000))), + distinct_count: Precision::Absent, + byte_size: Precision::Inexact(60), + }], + }; + + let items = vec![&stats1, &stats2]; + let summary_stats = Statistics::try_merge_iter(items, &schema).unwrap(); + + assert_eq!(summary_stats.num_rows, Precision::Inexact(30)); + assert_eq!(summary_stats.total_byte_size, Precision::Inexact(300)); + + let col_stats = &summary_stats.column_statistics[0]; + assert_eq!(col_stats.null_count, Precision::Inexact(3)); + assert_eq!( + col_stats.max_value, + Precision::Inexact(ScalarValue::Int32(Some(200))) + ); + assert_eq!( + col_stats.min_value, + Precision::Inexact(ScalarValue::Int32(Some(-5))) + ); + assert_eq!( + col_stats.sum_value, + Precision::Inexact(ScalarValue::Int32(Some(1500))) + ); + } } diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs new file mode 100644 index 0000000000000..d72f2ffba9eca --- /dev/null +++ b/datafusion/common/src/utils/aggregate.rs @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Array-level and scalar-level aggregation utilities (sum, min, max). +//! +//! These functions compute aggregate values over Arrow arrays or slices of +//! [`ScalarValue`]s, returning the result as a [`ScalarValue`]. They use +//! native Arrow compute kernels for primitive numeric types and fall back +//! to element-wise [`ScalarValue`] operations for other types. + +use arrow::array::ArrayRef; + +use crate::stats::Precision; +use crate::{Result, ScalarValue}; +use arrow::array::*; +use arrow::compute::{min as arrow_min, sum as arrow_sum}; +use arrow::datatypes::*; + +/// Returns true if the [`ScalarValue`] is a primitive numeric type that +/// benefits from Arrow vectorized kernels over direct `ScalarValue` +/// comparison. +pub(crate) fn is_primitive_scalar(value: &ScalarValue) -> bool { + matches!( + value, + | ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Float16(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Decimal32(_, _, _) + | ScalarValue::Decimal64(_, _, _) + | ScalarValue::Decimal128(_, _, _) + | ScalarValue::Decimal256(_, _, _) + | ScalarValue::Date32(_) + | ScalarValue::Date64(_) + ) +} + +/// Compute the minimum of [`ScalarValue`]s using direct `PartialOrd` +/// comparison. +pub(crate) fn scalar_min(values: &[ScalarValue]) -> ScalarValue { + debug_assert!(!values.is_empty()); + let mut result = values[0].clone(); + for v in &values[1..] { + if v.is_null() { + continue; + } + if result.is_null() || v < &result { + result = v.clone(); + } + } + result +} + +/// Compute the maximum of [`ScalarValue`]s using direct `PartialOrd` +/// comparison. +pub(crate) fn scalar_max(values: &[ScalarValue]) -> ScalarValue { + debug_assert!(!values.is_empty()); + let mut result = values[0].clone(); + for v in &values[1..] { + if v.is_null() { + continue; + } + if result.is_null() || v > &result { + result = v.clone(); + } + } + result +} + +/// Wrap a [`ScalarValue`] result with the appropriate [`Precision`] level. +pub(crate) fn wrap_precision( + value: ScalarValue, + all_exact: bool, +) -> Precision { + if value.is_null() { + return Precision::Absent; + } + if all_exact { + Precision::Exact(value) + } else { + Precision::Inexact(value) + } +} + +/// Compute the sum of a collection of [`ScalarValue`]s using vectorized +/// Arrow kernels. The values are converted to an Arrow array once, then +/// the sum kernel is applied in a single call. +pub(crate) fn vectorized_sum( + values: &[ScalarValue], + all_exact: bool, +) -> Result> { + debug_assert!(!values.is_empty()); + + let array = ScalarValue::iter_to_array(values.iter().cloned())?; + let result = sum_array(&array)?; + + Ok(wrap_precision(result, all_exact)) +} + +/// Compute minimum of a collection of [`ScalarValue`]s with arrows vector kernels. +pub(crate) fn vectorized_min( + values: &[ScalarValue], + all_exact: bool, +) -> Result> { + debug_assert!(!values.is_empty()); + + let result = if is_primitive_scalar(&values[0]) { + let array = ScalarValue::iter_to_array(values.iter().cloned())?; + min_array(&array)? + } else { + scalar_min(values) + }; + + Ok(wrap_precision(result, all_exact)) +} + +/// Compute the maximum of a collection of [`ScalarValue`]s with vector kernels. +pub(crate) fn vectorized_max( + values: &[ScalarValue], + all_exact: bool, +) -> Result> { + debug_assert!(!values.is_empty()); + + let result = if is_primitive_scalar(&values[0]) { + let array = ScalarValue::iter_to_array(values.iter().cloned())?; + max_array(&array)? + } else { + scalar_max(values) + }; + + Ok(wrap_precision(result, all_exact)) +} + +/// Compute sum of all elements in an Arrow array, returning the result as +/// a [`ScalarValue`]. +pub(crate) fn sum_array(array: &ArrayRef) -> Result { + macro_rules! sum_primitive { + ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ + let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); + match arrow_sum(typed) { + Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), + None => ScalarValue::try_from(array.data_type())?, + } + }}; + } + + let result = match array.data_type() { + DataType::Int8 => sum_primitive!(array, Int8Array, Int8), + DataType::Int16 => sum_primitive!(array, Int16Array, Int16), + DataType::Int32 => sum_primitive!(array, Int32Array, Int32), + DataType::Int64 => sum_primitive!(array, Int64Array, Int64), + DataType::UInt8 => sum_primitive!(array, UInt8Array, UInt8), + DataType::UInt16 => sum_primitive!(array, UInt16Array, UInt16), + DataType::UInt32 => sum_primitive!(array, UInt32Array, UInt32), + DataType::UInt64 => sum_primitive!(array, UInt64Array, UInt64), + DataType::Float16 => sum_primitive!(array, Float16Array, Float16), + DataType::Float32 => sum_primitive!(array, Float32Array, Float32), + DataType::Float64 => sum_primitive!(array, Float64Array, Float64), + DataType::Decimal32(p, s) => { + let p = *p; + let s = *s; + sum_primitive!(array, Decimal32Array, Decimal32, p, s) + } + DataType::Decimal64(p, s) => { + let p = *p; + let s = *s; + sum_primitive!(array, Decimal64Array, Decimal64, p, s) + } + DataType::Decimal128(p, s) => { + let p = *p; + let s = *s; + sum_primitive!(array, Decimal128Array, Decimal128, p, s) + } + DataType::Decimal256(p, s) => { + let p = *p; + let s = *s; + sum_primitive!(array, Decimal256Array, Decimal256, p, s) + } + _ => { + let mut acc = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let v = ScalarValue::try_from_array(array, i)?; + if !v.is_null() { + if acc.is_null() { + acc = v; + } else { + acc = acc.add(&v)?; + } + } + } + acc + } + }; + + Ok(result) +} + +/// Compute the minimum of all elements in an Arrow array, returning the +/// result as a [`ScalarValue`]. +pub(crate) fn min_array(array: &ArrayRef) -> Result { + macro_rules! min_primitive { + ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ + let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); + match arrow_min(typed) { + Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), + None => ScalarValue::try_from(array.data_type())?, + } + }}; + } + + macro_rules! min_string { + ($array:expr, $array_type:ty, $scalar_variant:ident) => {{ + let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); + match arrow::compute::min_string(typed) { + Some(v) => ScalarValue::$scalar_variant(Some(v.to_owned())), + None => ScalarValue::try_from(array.data_type())?, + } + }}; + } + + let result = match array.data_type() { + DataType::Int8 => min_primitive!(array, Int8Array, Int8), + DataType::Int16 => min_primitive!(array, Int16Array, Int16), + DataType::Int32 => min_primitive!(array, Int32Array, Int32), + DataType::Int64 => min_primitive!(array, Int64Array, Int64), + DataType::UInt8 => min_primitive!(array, UInt8Array, UInt8), + DataType::UInt16 => min_primitive!(array, UInt16Array, UInt16), + DataType::UInt32 => min_primitive!(array, UInt32Array, UInt32), + DataType::UInt64 => min_primitive!(array, UInt64Array, UInt64), + DataType::Float16 => min_primitive!(array, Float16Array, Float16), + DataType::Float32 => min_primitive!(array, Float32Array, Float32), + DataType::Float64 => min_primitive!(array, Float64Array, Float64), + DataType::Decimal32(p, s) => { + let p = *p; + let s = *s; + min_primitive!(array, Decimal32Array, Decimal32, p, s) + } + DataType::Decimal64(p, s) => { + let p = *p; + let s = *s; + min_primitive!(array, Decimal64Array, Decimal64, p, s) + } + DataType::Decimal128(p, s) => { + let p = *p; + let s = *s; + min_primitive!(array, Decimal128Array, Decimal128, p, s) + } + DataType::Decimal256(p, s) => { + let p = *p; + let s = *s; + min_primitive!(array, Decimal256Array, Decimal256, p, s) + } + DataType::Date32 => min_primitive!(array, Date32Array, Date32), + DataType::Date64 => min_primitive!(array, Date64Array, Date64), + DataType::Utf8 => min_string!(array, StringArray, Utf8), + DataType::LargeUtf8 => min_string!(array, LargeStringArray, LargeUtf8), + DataType::Utf8View => { + let typed = array.as_any().downcast_ref::().unwrap(); + match arrow::compute::min_string_view(typed) { + Some(v) => ScalarValue::Utf8View(Some(v.to_owned())), + None => ScalarValue::try_from(array.data_type())?, + } + } + DataType::Boolean => { + let typed = array.as_any().downcast_ref::().unwrap(); + match arrow::compute::min_boolean(typed) { + Some(v) => ScalarValue::Boolean(Some(v)), + None => ScalarValue::try_from(array.data_type())?, + } + } + _ => scalar_min_max_fallback(array, true)?, + }; + + Ok(result) +} + +/// Compute the maximum of all elements in an Arrow array, returning the +/// result as a [`ScalarValue`]. +/// +/// Uses native Arrow aggregate kernels for primitive numeric types, and +/// falls back to [`ScalarValue`] comparison for other types. +pub(crate) fn max_array(array: &ArrayRef) -> Result { + use arrow::array::*; + use arrow::compute::max as arrow_max; + use arrow::datatypes::*; + + macro_rules! max_primitive { + ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ + let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); + match arrow_max(typed) { + Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), + None => ScalarValue::try_from(array.data_type())?, + } + }}; + } + + macro_rules! max_string { + ($array:expr, $array_type:ty, $scalar_variant:ident) => {{ + let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); + match arrow::compute::max_string(typed) { + Some(v) => ScalarValue::$scalar_variant(Some(v.to_owned())), + None => ScalarValue::try_from(array.data_type())?, + } + }}; + } + + let result = match array.data_type() { + DataType::Int8 => max_primitive!(array, Int8Array, Int8), + DataType::Int16 => max_primitive!(array, Int16Array, Int16), + DataType::Int32 => max_primitive!(array, Int32Array, Int32), + DataType::Int64 => max_primitive!(array, Int64Array, Int64), + DataType::UInt8 => max_primitive!(array, UInt8Array, UInt8), + DataType::UInt16 => max_primitive!(array, UInt16Array, UInt16), + DataType::UInt32 => max_primitive!(array, UInt32Array, UInt32), + DataType::UInt64 => max_primitive!(array, UInt64Array, UInt64), + DataType::Float16 => max_primitive!(array, Float16Array, Float16), + DataType::Float32 => max_primitive!(array, Float32Array, Float32), + DataType::Float64 => max_primitive!(array, Float64Array, Float64), + DataType::Decimal32(p, s) => { + let p = *p; + let s = *s; + max_primitive!(array, Decimal32Array, Decimal32, p, s) + } + DataType::Decimal64(p, s) => { + let p = *p; + let s = *s; + max_primitive!(array, Decimal64Array, Decimal64, p, s) + } + DataType::Decimal128(p, s) => { + let p = *p; + let s = *s; + max_primitive!(array, Decimal128Array, Decimal128, p, s) + } + DataType::Decimal256(p, s) => { + let p = *p; + let s = *s; + max_primitive!(array, Decimal256Array, Decimal256, p, s) + } + DataType::Date32 => max_primitive!(array, Date32Array, Date32), + DataType::Date64 => max_primitive!(array, Date64Array, Date64), + DataType::Utf8 => max_string!(array, StringArray, Utf8), + DataType::LargeUtf8 => max_string!(array, LargeStringArray, LargeUtf8), + DataType::Utf8View => { + let typed = array.as_any().downcast_ref::().unwrap(); + match arrow::compute::max_string_view(typed) { + Some(v) => ScalarValue::Utf8View(Some(v.to_owned())), + None => ScalarValue::try_from(array.data_type())?, + } + } + DataType::Boolean => { + let typed = array.as_any().downcast_ref::().unwrap(); + match arrow::compute::max_boolean(typed) { + Some(v) => ScalarValue::Boolean(Some(v)), + None => ScalarValue::try_from(array.data_type())?, + } + } + _ => scalar_min_max_fallback(array, false)?, + }; + + Ok(result) +} diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7f2d78d57970e..075a189c371dc 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,6 +17,7 @@ //! This module provides the bisect function, which implements binary search. +pub(crate) mod aggregate; pub mod expr; pub mod memory; pub mod proxy; From 010ee19b9c1364d89d76bf62e29faaf79432f558 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 6 Mar 2026 19:56:32 -0600 Subject: [PATCH 2/6] int --- datafusion/common/src/utils/aggregate.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs index d72f2ffba9eca..095fdd7b01cc8 100644 --- a/datafusion/common/src/utils/aggregate.rs +++ b/datafusion/common/src/utils/aggregate.rs @@ -36,6 +36,8 @@ use arrow::datatypes::*; pub(crate) fn is_primitive_scalar(value: &ScalarValue) -> bool { matches!( value, + ScalarValue::Int8(_) + | ScalarValue::Int16(_) | ScalarValue::Int32(_) | ScalarValue::Int64(_) | ScalarValue::UInt8(_) From 6df99db357eeadff06d095084fc65a59c769a33c Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Fri, 6 Mar 2026 19:56:53 -0600 Subject: [PATCH 3/6] add fallback --- datafusion/common/src/utils/aggregate.rs | 27 ++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs index 095fdd7b01cc8..25c7c04465794 100644 --- a/datafusion/common/src/utils/aggregate.rs +++ b/datafusion/common/src/utils/aggregate.rs @@ -380,3 +380,30 @@ pub(crate) fn max_array(array: &ArrayRef) -> Result { Ok(result) } + +fn scalar_min_max_fallback(array: &ArrayRef, is_min: bool) -> Result { + if array.len() == array.null_count() { + return ScalarValue::try_from(array.data_type()); + } + + let mut result = ScalarValue::try_from_array(array, 0)?; + for i in 1..array.len() { + let current = ScalarValue::try_from_array(array, i)?; + if current.is_null() { + continue; + } + if result.is_null() { + result = current; + continue; + } + if is_min { + if current < result { + result = current; + } + } else if current > result { + result = current; + } + } + + Ok(result) +} From 1ad9ca7fdd39839fcc0b3121d098979665482f5e Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sat, 7 Mar 2026 14:38:21 -0600 Subject: [PATCH 4/6] optimize accumulation --- datafusion/common/src/utils/aggregate.rs | 422 +++++++---------------- 1 file changed, 129 insertions(+), 293 deletions(-) diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs index 25c7c04465794..09826afac6c1f 100644 --- a/datafusion/common/src/utils/aggregate.rs +++ b/datafusion/common/src/utils/aggregate.rs @@ -15,24 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! Array-level and scalar-level aggregation utilities (sum, min, max). +//! Scalar-level aggregation utilities (sum, min, max). //! -//! These functions compute aggregate values over Arrow arrays or slices of -//! [`ScalarValue`]s, returning the result as a [`ScalarValue`]. They use -//! native Arrow compute kernels for primitive numeric types and fall back -//! to element-wise [`ScalarValue`] operations for other types. +//! These functions compute aggregate values over slices of [`ScalarValue`]s +//! by directly extracting the inner primitive values and accumulating. -use arrow::array::ArrayRef; +use arrow::datatypes::i256; +use half::f16; use crate::stats::Precision; use crate::{Result, ScalarValue}; -use arrow::array::*; -use arrow::compute::{min as arrow_min, sum as arrow_sum}; -use arrow::datatypes::*; -/// Returns true if the [`ScalarValue`] is a primitive numeric type that -/// benefits from Arrow vectorized kernels over direct `ScalarValue` -/// comparison. +/// Returns true if the [`ScalarValue`] is a primitive numeric type. pub(crate) fn is_primitive_scalar(value: &ScalarValue) -> bool { matches!( value, @@ -88,6 +82,120 @@ pub(crate) fn scalar_max(values: &[ScalarValue]) -> ScalarValue { result } +/// Compute the sum of [`ScalarValue`]s by directly extracting and +/// accumulating primitive values without any Arrow array allocation. +/// +/// For non-primitive types, falls back to `ScalarValue::add`. +pub(crate) fn scalar_sum(values: &[ScalarValue]) -> Result { + debug_assert!(!values.is_empty()); + + macro_rules! sum_wrapping { + ($values:expr, $VARIANT:ident, $T:ty) => {{ + let mut has_value = false; + let mut acc: $T = Default::default(); + for sv in $values { + if let ScalarValue::$VARIANT(Some(v)) = sv { + if has_value { + acc = acc.wrapping_add(*v); + } else { + acc = *v; + has_value = true; + } + } + } + if has_value { + ScalarValue::$VARIANT(Some(acc)) + } else { + $values[0].clone() // all null + } + }}; + } + + /// Accumulate a wrapping sum for a decimal ScalarValue variant that + /// carries precision and scale fields. + macro_rules! sum_decimal { + ($values:expr, $VARIANT:ident, $T:ty) => {{ + let (p, s) = match &$values[0] { + ScalarValue::$VARIANT(_, p, s) => (*p, *s), + _ => unreachable!(), + }; + let mut has_value = false; + let mut acc: $T = Default::default(); + for sv in $values { + if let ScalarValue::$VARIANT(Some(v), _, _) = sv { + if has_value { + acc = acc.wrapping_add(*v); + } else { + acc = *v; + has_value = true; + } + } + } + if has_value { + ScalarValue::$VARIANT(Some(acc), p, s) + } else { + $values[0].clone() // all null + } + }}; + } + + macro_rules! sum_float { + ($values:expr, $VARIANT:ident, $T:ty) => {{ + let mut has_value = false; + let mut acc: $T = Default::default(); + for sv in $values { + if let ScalarValue::$VARIANT(Some(v)) = sv { + if has_value { + acc = acc + *v; + } else { + acc = *v; + has_value = true; + } + } + } + if has_value { + ScalarValue::$VARIANT(Some(acc)) + } else { + $values[0].clone() // all null + } + }}; + } + + let result = match &values[0] { + ScalarValue::Int8(_) => sum_wrapping!(values, Int8, i8), + ScalarValue::Int16(_) => sum_wrapping!(values, Int16, i16), + ScalarValue::Int32(_) => sum_wrapping!(values, Int32, i32), + ScalarValue::Int64(_) => sum_wrapping!(values, Int64, i64), + ScalarValue::UInt8(_) => sum_wrapping!(values, UInt8, u8), + ScalarValue::UInt16(_) => sum_wrapping!(values, UInt16, u16), + ScalarValue::UInt32(_) => sum_wrapping!(values, UInt32, u32), + ScalarValue::UInt64(_) => sum_wrapping!(values, UInt64, u64), + ScalarValue::Float16(_) => sum_float!(values, Float16, f16), + ScalarValue::Float32(_) => sum_float!(values, Float32, f32), + ScalarValue::Float64(_) => sum_float!(values, Float64, f64), + ScalarValue::Decimal32(_, _, _) => sum_decimal!(values, Decimal32, i32), + ScalarValue::Decimal64(_, _, _) => sum_decimal!(values, Decimal64, i64), + ScalarValue::Decimal128(_, _, _) => sum_decimal!(values, Decimal128, i128), + ScalarValue::Decimal256(_, _, _) => sum_decimal!(values, Decimal256, i256), + _ => { + // Fallback for non-primitive types: use ScalarValue::add + let mut acc = values[0].clone(); + for v in &values[1..] { + if !v.is_null() { + if acc.is_null() { + acc = v.clone(); + } else { + acc = acc.add(v)?; + } + } + } + acc + } + }; + + Ok(result) +} + /// Wrap a [`ScalarValue`] result with the appropriate [`Precision`] level. pub(crate) fn wrap_precision( value: ScalarValue, @@ -103,307 +211,35 @@ pub(crate) fn wrap_precision( } } -/// Compute the sum of a collection of [`ScalarValue`]s using vectorized -/// Arrow kernels. The values are converted to an Arrow array once, then -/// the sum kernel is applied in a single call. +/// Compute the sum of a collection of [`ScalarValue`]s by directly +/// extracting primitive values and accumulating without array allocation. pub(crate) fn vectorized_sum( values: &[ScalarValue], all_exact: bool, ) -> Result> { debug_assert!(!values.is_empty()); - - let array = ScalarValue::iter_to_array(values.iter().cloned())?; - let result = sum_array(&array)?; - + let result = scalar_sum(values)?; Ok(wrap_precision(result, all_exact)) } -/// Compute minimum of a collection of [`ScalarValue`]s with arrows vector kernels. +/// Compute minimum of a collection of [`ScalarValue`]s using direct +/// `PartialOrd` comparison. pub(crate) fn vectorized_min( values: &[ScalarValue], all_exact: bool, ) -> Result> { debug_assert!(!values.is_empty()); - - let result = if is_primitive_scalar(&values[0]) { - let array = ScalarValue::iter_to_array(values.iter().cloned())?; - min_array(&array)? - } else { - scalar_min(values) - }; - + let result = scalar_min(values); Ok(wrap_precision(result, all_exact)) } -/// Compute the maximum of a collection of [`ScalarValue`]s with vector kernels. +/// Compute the maximum of a collection of [`ScalarValue`]s using direct +/// `PartialOrd` comparison. pub(crate) fn vectorized_max( values: &[ScalarValue], all_exact: bool, ) -> Result> { debug_assert!(!values.is_empty()); - - let result = if is_primitive_scalar(&values[0]) { - let array = ScalarValue::iter_to_array(values.iter().cloned())?; - max_array(&array)? - } else { - scalar_max(values) - }; - + let result = scalar_max(values); Ok(wrap_precision(result, all_exact)) } - -/// Compute sum of all elements in an Arrow array, returning the result as -/// a [`ScalarValue`]. -pub(crate) fn sum_array(array: &ArrayRef) -> Result { - macro_rules! sum_primitive { - ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ - let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); - match arrow_sum(typed) { - Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), - None => ScalarValue::try_from(array.data_type())?, - } - }}; - } - - let result = match array.data_type() { - DataType::Int8 => sum_primitive!(array, Int8Array, Int8), - DataType::Int16 => sum_primitive!(array, Int16Array, Int16), - DataType::Int32 => sum_primitive!(array, Int32Array, Int32), - DataType::Int64 => sum_primitive!(array, Int64Array, Int64), - DataType::UInt8 => sum_primitive!(array, UInt8Array, UInt8), - DataType::UInt16 => sum_primitive!(array, UInt16Array, UInt16), - DataType::UInt32 => sum_primitive!(array, UInt32Array, UInt32), - DataType::UInt64 => sum_primitive!(array, UInt64Array, UInt64), - DataType::Float16 => sum_primitive!(array, Float16Array, Float16), - DataType::Float32 => sum_primitive!(array, Float32Array, Float32), - DataType::Float64 => sum_primitive!(array, Float64Array, Float64), - DataType::Decimal32(p, s) => { - let p = *p; - let s = *s; - sum_primitive!(array, Decimal32Array, Decimal32, p, s) - } - DataType::Decimal64(p, s) => { - let p = *p; - let s = *s; - sum_primitive!(array, Decimal64Array, Decimal64, p, s) - } - DataType::Decimal128(p, s) => { - let p = *p; - let s = *s; - sum_primitive!(array, Decimal128Array, Decimal128, p, s) - } - DataType::Decimal256(p, s) => { - let p = *p; - let s = *s; - sum_primitive!(array, Decimal256Array, Decimal256, p, s) - } - _ => { - let mut acc = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { - let v = ScalarValue::try_from_array(array, i)?; - if !v.is_null() { - if acc.is_null() { - acc = v; - } else { - acc = acc.add(&v)?; - } - } - } - acc - } - }; - - Ok(result) -} - -/// Compute the minimum of all elements in an Arrow array, returning the -/// result as a [`ScalarValue`]. -pub(crate) fn min_array(array: &ArrayRef) -> Result { - macro_rules! min_primitive { - ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ - let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); - match arrow_min(typed) { - Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), - None => ScalarValue::try_from(array.data_type())?, - } - }}; - } - - macro_rules! min_string { - ($array:expr, $array_type:ty, $scalar_variant:ident) => {{ - let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); - match arrow::compute::min_string(typed) { - Some(v) => ScalarValue::$scalar_variant(Some(v.to_owned())), - None => ScalarValue::try_from(array.data_type())?, - } - }}; - } - - let result = match array.data_type() { - DataType::Int8 => min_primitive!(array, Int8Array, Int8), - DataType::Int16 => min_primitive!(array, Int16Array, Int16), - DataType::Int32 => min_primitive!(array, Int32Array, Int32), - DataType::Int64 => min_primitive!(array, Int64Array, Int64), - DataType::UInt8 => min_primitive!(array, UInt8Array, UInt8), - DataType::UInt16 => min_primitive!(array, UInt16Array, UInt16), - DataType::UInt32 => min_primitive!(array, UInt32Array, UInt32), - DataType::UInt64 => min_primitive!(array, UInt64Array, UInt64), - DataType::Float16 => min_primitive!(array, Float16Array, Float16), - DataType::Float32 => min_primitive!(array, Float32Array, Float32), - DataType::Float64 => min_primitive!(array, Float64Array, Float64), - DataType::Decimal32(p, s) => { - let p = *p; - let s = *s; - min_primitive!(array, Decimal32Array, Decimal32, p, s) - } - DataType::Decimal64(p, s) => { - let p = *p; - let s = *s; - min_primitive!(array, Decimal64Array, Decimal64, p, s) - } - DataType::Decimal128(p, s) => { - let p = *p; - let s = *s; - min_primitive!(array, Decimal128Array, Decimal128, p, s) - } - DataType::Decimal256(p, s) => { - let p = *p; - let s = *s; - min_primitive!(array, Decimal256Array, Decimal256, p, s) - } - DataType::Date32 => min_primitive!(array, Date32Array, Date32), - DataType::Date64 => min_primitive!(array, Date64Array, Date64), - DataType::Utf8 => min_string!(array, StringArray, Utf8), - DataType::LargeUtf8 => min_string!(array, LargeStringArray, LargeUtf8), - DataType::Utf8View => { - let typed = array.as_any().downcast_ref::().unwrap(); - match arrow::compute::min_string_view(typed) { - Some(v) => ScalarValue::Utf8View(Some(v.to_owned())), - None => ScalarValue::try_from(array.data_type())?, - } - } - DataType::Boolean => { - let typed = array.as_any().downcast_ref::().unwrap(); - match arrow::compute::min_boolean(typed) { - Some(v) => ScalarValue::Boolean(Some(v)), - None => ScalarValue::try_from(array.data_type())?, - } - } - _ => scalar_min_max_fallback(array, true)?, - }; - - Ok(result) -} - -/// Compute the maximum of all elements in an Arrow array, returning the -/// result as a [`ScalarValue`]. -/// -/// Uses native Arrow aggregate kernels for primitive numeric types, and -/// falls back to [`ScalarValue`] comparison for other types. -pub(crate) fn max_array(array: &ArrayRef) -> Result { - use arrow::array::*; - use arrow::compute::max as arrow_max; - use arrow::datatypes::*; - - macro_rules! max_primitive { - ($array:expr, $array_type:ty, $scalar_variant:ident $(, $extra:expr)*) => {{ - let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); - match arrow_max(typed) { - Some(v) => ScalarValue::$scalar_variant(Some(v) $(, $extra)*), - None => ScalarValue::try_from(array.data_type())?, - } - }}; - } - - macro_rules! max_string { - ($array:expr, $array_type:ty, $scalar_variant:ident) => {{ - let typed = $array.as_any().downcast_ref::<$array_type>().unwrap(); - match arrow::compute::max_string(typed) { - Some(v) => ScalarValue::$scalar_variant(Some(v.to_owned())), - None => ScalarValue::try_from(array.data_type())?, - } - }}; - } - - let result = match array.data_type() { - DataType::Int8 => max_primitive!(array, Int8Array, Int8), - DataType::Int16 => max_primitive!(array, Int16Array, Int16), - DataType::Int32 => max_primitive!(array, Int32Array, Int32), - DataType::Int64 => max_primitive!(array, Int64Array, Int64), - DataType::UInt8 => max_primitive!(array, UInt8Array, UInt8), - DataType::UInt16 => max_primitive!(array, UInt16Array, UInt16), - DataType::UInt32 => max_primitive!(array, UInt32Array, UInt32), - DataType::UInt64 => max_primitive!(array, UInt64Array, UInt64), - DataType::Float16 => max_primitive!(array, Float16Array, Float16), - DataType::Float32 => max_primitive!(array, Float32Array, Float32), - DataType::Float64 => max_primitive!(array, Float64Array, Float64), - DataType::Decimal32(p, s) => { - let p = *p; - let s = *s; - max_primitive!(array, Decimal32Array, Decimal32, p, s) - } - DataType::Decimal64(p, s) => { - let p = *p; - let s = *s; - max_primitive!(array, Decimal64Array, Decimal64, p, s) - } - DataType::Decimal128(p, s) => { - let p = *p; - let s = *s; - max_primitive!(array, Decimal128Array, Decimal128, p, s) - } - DataType::Decimal256(p, s) => { - let p = *p; - let s = *s; - max_primitive!(array, Decimal256Array, Decimal256, p, s) - } - DataType::Date32 => max_primitive!(array, Date32Array, Date32), - DataType::Date64 => max_primitive!(array, Date64Array, Date64), - DataType::Utf8 => max_string!(array, StringArray, Utf8), - DataType::LargeUtf8 => max_string!(array, LargeStringArray, LargeUtf8), - DataType::Utf8View => { - let typed = array.as_any().downcast_ref::().unwrap(); - match arrow::compute::max_string_view(typed) { - Some(v) => ScalarValue::Utf8View(Some(v.to_owned())), - None => ScalarValue::try_from(array.data_type())?, - } - } - DataType::Boolean => { - let typed = array.as_any().downcast_ref::().unwrap(); - match arrow::compute::max_boolean(typed) { - Some(v) => ScalarValue::Boolean(Some(v)), - None => ScalarValue::try_from(array.data_type())?, - } - } - _ => scalar_min_max_fallback(array, false)?, - }; - - Ok(result) -} - -fn scalar_min_max_fallback(array: &ArrayRef, is_min: bool) -> Result { - if array.len() == array.null_count() { - return ScalarValue::try_from(array.data_type()); - } - - let mut result = ScalarValue::try_from_array(array, 0)?; - for i in 1..array.len() { - let current = ScalarValue::try_from_array(array, i)?; - if current.is_null() { - continue; - } - if result.is_null() { - result = current; - continue; - } - if is_min { - if current < result { - result = current; - } - } else if current > result { - result = current; - } - } - - Ok(result) -} From e51023e644e82724b65ffa0e251d53a3f5a5d2ff Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sat, 7 Mar 2026 15:23:52 -0600 Subject: [PATCH 5/6] clippy: use += instead of acc = acc + *v --- datafusion/common/src/utils/aggregate.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs index 09826afac6c1f..9a4f380beaeff 100644 --- a/datafusion/common/src/utils/aggregate.rs +++ b/datafusion/common/src/utils/aggregate.rs @@ -146,7 +146,7 @@ pub(crate) fn scalar_sum(values: &[ScalarValue]) -> Result { for sv in $values { if let ScalarValue::$VARIANT(Some(v)) = sv { if has_value { - acc = acc + *v; + acc += *v; } else { acc = *v; has_value = true; From c1724a2d57729fd74d4777100a463504d86600c0 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Sat, 7 Mar 2026 15:37:47 -0600 Subject: [PATCH 6/6] simplify: single-loop accumulation with cheap scalar_add --- datafusion/common/src/stats.rs | 114 +-------- datafusion/common/src/utils/aggregate.rs | 280 ++++++----------------- 2 files changed, 85 insertions(+), 309 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index bc799b778593b..0bdfe2d2ce372 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -22,9 +22,7 @@ use std::fmt::{self, Debug, Display}; use crate::{Result, ScalarValue}; use crate::error::_plan_err; -use crate::utils::aggregate::{ - is_primitive_scalar, vectorized_max, vectorized_min, vectorized_sum, -}; +use crate::utils::aggregate::precision_add; use arrow::datatypes::{DataType, Schema}; /// Represents a value with a degree of certainty. `Precision` is used to @@ -651,23 +649,6 @@ impl Statistics { total_byte_size = total_byte_size.add(&stat.total_byte_size); } - // We look at the first available value type for each column to decide whether to - // use vectorized aggregation - let col_is_primitive: Vec = - (0..num_cols) - .map(|col_idx| { - items - .first() - .and_then(|s| { - s.column_statistics[col_idx].sum_value.get_value().or_else( - || s.column_statistics[col_idx].min_value.get_value(), - ) - }) - .map(is_primitive_scalar) - .unwrap_or(false) - }) - .collect(); - let first = items[0]; let mut column_statistics: Vec = first .column_statistics @@ -682,98 +663,21 @@ impl Statistics { }) .collect(); - // fold non-primitive columns and accumulate usize - // stats for all columns. + // Accumulate all statistics in a single pass. + // Uses precision_add for sum (avoids the expensive + // ScalarValue::add round-trip through Arrow arrays), and + // Precision::min/max which use cheap PartialOrd comparison. for stat in items.iter().skip(1) { for (col_idx, col_stats) in column_statistics.iter_mut().enumerate() { let item_cs = &stat.column_statistics[col_idx]; col_stats.null_count = col_stats.null_count.add(&item_cs.null_count); col_stats.byte_size = col_stats.byte_size.add(&item_cs.byte_size); - - // Non-primitive columns: fold sum/min/max directly using - // Precision::add/min/max - if !col_is_primitive[col_idx] { - col_stats.sum_value = col_stats.sum_value.add(&item_cs.sum_value); - col_stats.min_value = col_stats.min_value.min(&item_cs.min_value); - col_stats.max_value = col_stats.max_value.max(&item_cs.max_value); - } - } - } - - // Column-major vectorized pass: only for primitive columns. - // Collects values into Vecs and uses Arrow kernels for sum/min/max. - for col_idx in 0..num_cols { - if !col_is_primitive[col_idx] { - continue; - } - - // These booleans check whether the values were either all exact or if there - // were any inexact - let mut sum_values: Vec = Vec::with_capacity(items.len()); - let mut min_values: Vec = Vec::with_capacity(items.len()); - let mut max_values: Vec = Vec::with_capacity(items.len()); - let mut sum_all_exact = true; - let mut min_all_exact = true; - let mut max_all_exact = true; - let mut sum_any_absent = false; - let mut min_any_absent = false; - let mut max_any_absent = false; - - for stat in &items { - let cs = &stat.column_statistics[col_idx]; - - match &cs.sum_value { - Precision::Exact(v) => sum_values.push(v.clone()), - Precision::Inexact(v) => { - sum_all_exact = false; - sum_values.push(v.clone()); - } - Precision::Absent => { - sum_any_absent = true; - } - } - match &cs.min_value { - Precision::Exact(v) => min_values.push(v.clone()), - Precision::Inexact(v) => { - min_all_exact = false; - min_values.push(v.clone()); - } - Precision::Absent => { - min_any_absent = true; - } - } - match &cs.max_value { - Precision::Exact(v) => max_values.push(v.clone()), - Precision::Inexact(v) => { - max_all_exact = false; - max_values.push(v.clone()); - } - Precision::Absent => { - max_any_absent = true; - } - } + col_stats.sum_value = + precision_add(&col_stats.sum_value, &item_cs.sum_value); + col_stats.min_value = col_stats.min_value.min(&item_cs.min_value); + col_stats.max_value = col_stats.max_value.max(&item_cs.max_value); } - - let col_stats = &mut column_statistics[col_idx]; - - col_stats.sum_value = if sum_any_absent || sum_values.is_empty() { - Precision::Absent - } else { - vectorized_sum(&sum_values, sum_all_exact)? - }; - - col_stats.min_value = if min_any_absent || min_values.is_empty() { - Precision::Absent - } else { - vectorized_min(&min_values, min_all_exact)? - }; - - col_stats.max_value = if max_any_absent || max_values.is_empty() { - Precision::Absent - } else { - vectorized_max(&max_values, max_all_exact)? - }; } Ok(Statistics { diff --git a/datafusion/common/src/utils/aggregate.rs b/datafusion/common/src/utils/aggregate.rs index 9a4f380beaeff..16cc3bc033197 100644 --- a/datafusion/common/src/utils/aggregate.rs +++ b/datafusion/common/src/utils/aggregate.rs @@ -15,231 +15,103 @@ // specific language governing permissions and limitations // under the License. -//! Scalar-level aggregation utilities (sum, min, max). +//! Scalar-level aggregation utilities for statistics merging. //! -//! These functions compute aggregate values over slices of [`ScalarValue`]s -//! by directly extracting the inner primitive values and accumulating. - -use arrow::datatypes::i256; -use half::f16; +//! Provides a cheap pairwise [`ScalarValue`] addition that directly +//! extracts inner primitive values, avoiding the expensive +//! `ScalarValue::add` path (which round-trips through Arrow arrays). use crate::stats::Precision; use crate::{Result, ScalarValue}; -/// Returns true if the [`ScalarValue`] is a primitive numeric type. -pub(crate) fn is_primitive_scalar(value: &ScalarValue) -> bool { - matches!( - value, - ScalarValue::Int8(_) - | ScalarValue::Int16(_) - | ScalarValue::Int32(_) - | ScalarValue::Int64(_) - | ScalarValue::UInt8(_) - | ScalarValue::UInt16(_) - | ScalarValue::UInt32(_) - | ScalarValue::UInt64(_) - | ScalarValue::Float16(_) - | ScalarValue::Float32(_) - | ScalarValue::Float64(_) - | ScalarValue::Decimal32(_, _, _) - | ScalarValue::Decimal64(_, _, _) - | ScalarValue::Decimal128(_, _, _) - | ScalarValue::Decimal256(_, _, _) - | ScalarValue::Date32(_) - | ScalarValue::Date64(_) - ) -} - -/// Compute the minimum of [`ScalarValue`]s using direct `PartialOrd` -/// comparison. -pub(crate) fn scalar_min(values: &[ScalarValue]) -> ScalarValue { - debug_assert!(!values.is_empty()); - let mut result = values[0].clone(); - for v in &values[1..] { - if v.is_null() { - continue; - } - if result.is_null() || v < &result { - result = v.clone(); - } - } - result -} - -/// Compute the maximum of [`ScalarValue`]s using direct `PartialOrd` -/// comparison. -pub(crate) fn scalar_max(values: &[ScalarValue]) -> ScalarValue { - debug_assert!(!values.is_empty()); - let mut result = values[0].clone(); - for v in &values[1..] { - if v.is_null() { - continue; - } - if result.is_null() || v > &result { - result = v.clone(); - } - } - result -} - -/// Compute the sum of [`ScalarValue`]s by directly extracting and -/// accumulating primitive values without any Arrow array allocation. +/// Add two [`ScalarValue`]s by directly extracting and adding their +/// inner primitive values. +/// +/// This avoids `ScalarValue::add` which converts both operands to +/// single-element Arrow arrays, runs the `add_wrapping` kernel, and +/// converts the result back — 3 heap allocations per call. /// /// For non-primitive types, falls back to `ScalarValue::add`. -pub(crate) fn scalar_sum(values: &[ScalarValue]) -> Result { - debug_assert!(!values.is_empty()); - - macro_rules! sum_wrapping { - ($values:expr, $VARIANT:ident, $T:ty) => {{ - let mut has_value = false; - let mut acc: $T = Default::default(); - for sv in $values { - if let ScalarValue::$VARIANT(Some(v)) = sv { - if has_value { - acc = acc.wrapping_add(*v); - } else { - acc = *v; - has_value = true; - } +pub(crate) fn scalar_add(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { + macro_rules! add_wrapping { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + (ScalarValue::$VARIANT(Some(a)), ScalarValue::$VARIANT(Some(b))) => { + Ok(ScalarValue::$VARIANT(Some(a.wrapping_add(*b)))) } + (ScalarValue::$VARIANT(None), other) + | (other, ScalarValue::$VARIANT(None)) => Ok(other.clone()), + _ => unreachable!(), } - if has_value { - ScalarValue::$VARIANT(Some(acc)) - } else { - $values[0].clone() // all null - } - }}; + }; } - /// Accumulate a wrapping sum for a decimal ScalarValue variant that - /// carries precision and scale fields. - macro_rules! sum_decimal { - ($values:expr, $VARIANT:ident, $T:ty) => {{ - let (p, s) = match &$values[0] { - ScalarValue::$VARIANT(_, p, s) => (*p, *s), + macro_rules! add_decimal { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + ( + ScalarValue::$VARIANT(Some(a), p, s), + ScalarValue::$VARIANT(Some(b), _, _), + ) => Ok(ScalarValue::$VARIANT(Some(a.wrapping_add(*b)), *p, *s)), + (ScalarValue::$VARIANT(None, _, _), other) + | (other, ScalarValue::$VARIANT(None, _, _)) => Ok(other.clone()), _ => unreachable!(), - }; - let mut has_value = false; - let mut acc: $T = Default::default(); - for sv in $values { - if let ScalarValue::$VARIANT(Some(v), _, _) = sv { - if has_value { - acc = acc.wrapping_add(*v); - } else { - acc = *v; - has_value = true; - } - } - } - if has_value { - ScalarValue::$VARIANT(Some(acc), p, s) - } else { - $values[0].clone() // all null } - }}; + }; } - macro_rules! sum_float { - ($values:expr, $VARIANT:ident, $T:ty) => {{ - let mut has_value = false; - let mut acc: $T = Default::default(); - for sv in $values { - if let ScalarValue::$VARIANT(Some(v)) = sv { - if has_value { - acc += *v; - } else { - acc = *v; - has_value = true; - } + macro_rules! add_float { + ($lhs:expr, $rhs:expr, $VARIANT:ident) => { + match ($lhs, $rhs) { + (ScalarValue::$VARIANT(Some(a)), ScalarValue::$VARIANT(Some(b))) => { + Ok(ScalarValue::$VARIANT(Some(*a + *b))) } + (ScalarValue::$VARIANT(None), other) + | (other, ScalarValue::$VARIANT(None)) => Ok(other.clone()), + _ => unreachable!(), } - if has_value { - ScalarValue::$VARIANT(Some(acc)) - } else { - $values[0].clone() // all null - } - }}; + }; } - let result = match &values[0] { - ScalarValue::Int8(_) => sum_wrapping!(values, Int8, i8), - ScalarValue::Int16(_) => sum_wrapping!(values, Int16, i16), - ScalarValue::Int32(_) => sum_wrapping!(values, Int32, i32), - ScalarValue::Int64(_) => sum_wrapping!(values, Int64, i64), - ScalarValue::UInt8(_) => sum_wrapping!(values, UInt8, u8), - ScalarValue::UInt16(_) => sum_wrapping!(values, UInt16, u16), - ScalarValue::UInt32(_) => sum_wrapping!(values, UInt32, u32), - ScalarValue::UInt64(_) => sum_wrapping!(values, UInt64, u64), - ScalarValue::Float16(_) => sum_float!(values, Float16, f16), - ScalarValue::Float32(_) => sum_float!(values, Float32, f32), - ScalarValue::Float64(_) => sum_float!(values, Float64, f64), - ScalarValue::Decimal32(_, _, _) => sum_decimal!(values, Decimal32, i32), - ScalarValue::Decimal64(_, _, _) => sum_decimal!(values, Decimal64, i64), - ScalarValue::Decimal128(_, _, _) => sum_decimal!(values, Decimal128, i128), - ScalarValue::Decimal256(_, _, _) => sum_decimal!(values, Decimal256, i256), - _ => { - // Fallback for non-primitive types: use ScalarValue::add - let mut acc = values[0].clone(); - for v in &values[1..] { - if !v.is_null() { - if acc.is_null() { - acc = v.clone(); - } else { - acc = acc.add(v)?; - } - } - } - acc - } - }; - - Ok(result) + match lhs { + ScalarValue::Int8(_) => add_wrapping!(lhs, rhs, Int8), + ScalarValue::Int16(_) => add_wrapping!(lhs, rhs, Int16), + ScalarValue::Int32(_) => add_wrapping!(lhs, rhs, Int32), + ScalarValue::Int64(_) => add_wrapping!(lhs, rhs, Int64), + ScalarValue::UInt8(_) => add_wrapping!(lhs, rhs, UInt8), + ScalarValue::UInt16(_) => add_wrapping!(lhs, rhs, UInt16), + ScalarValue::UInt32(_) => add_wrapping!(lhs, rhs, UInt32), + ScalarValue::UInt64(_) => add_wrapping!(lhs, rhs, UInt64), + ScalarValue::Float16(_) => add_float!(lhs, rhs, Float16), + ScalarValue::Float32(_) => add_float!(lhs, rhs, Float32), + ScalarValue::Float64(_) => add_float!(lhs, rhs, Float64), + ScalarValue::Decimal32(_, _, _) => add_decimal!(lhs, rhs, Decimal32), + ScalarValue::Decimal64(_, _, _) => add_decimal!(lhs, rhs, Decimal64), + ScalarValue::Decimal128(_, _, _) => add_decimal!(lhs, rhs, Decimal128), + ScalarValue::Decimal256(_, _, _) => add_decimal!(lhs, rhs, Decimal256), + // Fallback: use the existing ScalarValue::add + _ => lhs.add(rhs), + } } -/// Wrap a [`ScalarValue`] result with the appropriate [`Precision`] level. -pub(crate) fn wrap_precision( - value: ScalarValue, - all_exact: bool, +/// [`Precision`]-aware sum of two [`ScalarValue`] precisions using +/// cheap direct addition via [`scalar_add`]. +/// +/// Mirrors the semantics of `Precision::add` but avoids +/// the expensive `ScalarValue::add` round-trip through Arrow arrays. +pub(crate) fn precision_add( + lhs: &Precision, + rhs: &Precision, ) -> Precision { - if value.is_null() { - return Precision::Absent; + match (lhs, rhs) { + (Precision::Exact(a), Precision::Exact(b)) => scalar_add(a, b) + .map(Precision::Exact) + .unwrap_or(Precision::Absent), + (Precision::Inexact(a), Precision::Exact(b)) + | (Precision::Exact(a), Precision::Inexact(b)) + | (Precision::Inexact(a), Precision::Inexact(b)) => scalar_add(a, b) + .map(Precision::Inexact) + .unwrap_or(Precision::Absent), + (_, _) => Precision::Absent, } - if all_exact { - Precision::Exact(value) - } else { - Precision::Inexact(value) - } -} - -/// Compute the sum of a collection of [`ScalarValue`]s by directly -/// extracting primitive values and accumulating without array allocation. -pub(crate) fn vectorized_sum( - values: &[ScalarValue], - all_exact: bool, -) -> Result> { - debug_assert!(!values.is_empty()); - let result = scalar_sum(values)?; - Ok(wrap_precision(result, all_exact)) -} - -/// Compute minimum of a collection of [`ScalarValue`]s using direct -/// `PartialOrd` comparison. -pub(crate) fn vectorized_min( - values: &[ScalarValue], - all_exact: bool, -) -> Result> { - debug_assert!(!values.is_empty()); - let result = scalar_min(values); - Ok(wrap_precision(result, all_exact)) -} - -/// Compute the maximum of a collection of [`ScalarValue`]s using direct -/// `PartialOrd` comparison. -pub(crate) fn vectorized_max( - values: &[ScalarValue], - all_exact: bool, -) -> Result> { - debug_assert!(!values.is_empty()); - let result = scalar_max(values); - Ok(wrap_precision(result, all_exact)) }