diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index fa021ed3dcce3..57e7d17a9340a 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -929,7 +929,10 @@ mod test { num_rows: Precision::Exact(0), total_byte_size: Precision::Absent, column_statistics: vec![ - ColumnStatistics::new_unknown(), + ColumnStatistics { + distinct_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b9a1e49ab5220..8d71cace37fb2 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1037,10 +1037,37 @@ impl AggregateExec { &self.input_order_mode } + /// Estimate the number of output groups from group-by column distinct counts. + /// + /// When all group-by expressions are simple column references with known + /// `distinct_count` statistics, returns the product of those distinct counts + /// as an upper-bound estimate. Returns `None` if any group-by expression + /// is not a simple column or lacks a distinct count. + fn estimate_group_count(&self, child_statistics: &Statistics) -> Option { + if self.group_by.expr.is_empty() { + return None; + } + + let mut product: usize = 1; + for (expr, _) in self.group_by.expr.iter() { + let col = expr.as_any().downcast_ref::()?; + let dc = child_statistics.column_statistics[col.index()] + .distinct_count + .get_value()?; // function returns None if value is Absent + if *dc == 0 { + // A distinct count of zero means no data or unknown — don't + // use it to collapse the entire estimate to zero. + return None; + } + product = product.saturating_mul(*dc); + } + + Some(product) + } + fn statistics_inner(&self, child_statistics: &Statistics) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here - // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: // - aggregations sometimes also preserve invariants such as min, max... @@ -1050,20 +1077,19 @@ impl AggregateExec { for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { if let Some(col) = expr.as_any().downcast_ref::() { - column_statistics[idx].max_value = child_statistics.column_statistics - [col.index()] - .max_value - .clone(); - - column_statistics[idx].min_value = child_statistics.column_statistics - [col.index()] - .min_value - .clone(); + let child_col_stats = + &child_statistics.column_statistics[col.index()]; + + column_statistics[idx].max_value = child_col_stats.max_value.clone(); + column_statistics[idx].min_value = child_col_stats.min_value.clone(); + column_statistics[idx].distinct_count = + child_col_stats.distinct_count; } } column_statistics }; + match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned if self.group_by.expr.is_empty() => @@ -1083,7 +1109,19 @@ impl AggregateExec { let num_rows = if let Some(value) = child_statistics.num_rows.get_value() { if *value > 1 { - child_statistics.num_rows.to_inexact() + // Use distinct_count statistics from group-by columns to + // estimate output cardinality. The number of output groups + // is at most the product of distinct values across all + // group-by columns, capped by the input row count. + let group_by_estimate = + self.estimate_group_count(child_statistics); + match group_by_estimate { + Some(estimated_groups) => { + let capped = estimated_groups.min(*value); + Precision::Inexact(capped) + } + None => child_statistics.num_rows.to_inexact(), + } } else if *value == 0 { child_statistics.num_rows } else { @@ -3792,6 +3830,188 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_statistics_distinct_count() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float64, false), + ])); + + // Test 1: Single group-by column with known distinct_count + // 1000 input rows, column "a" has 10 distinct values + // => output should be Inexact(10) + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics::new_unknown(), + ], + }, + (*schema).clone(), + )) as Arc; + + let group_by_a = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + + let agg = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by_a, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(c)") + .build()?, + )], + vec![None], + Arc::clone(&input), + Arc::clone(&schema), + )?); + + let stats = agg.partition_statistics(None)?; + assert_eq!(stats.num_rows, Precision::Inexact(10)); + // distinct_count should be propagated to group-by output column + assert_eq!( + stats.column_statistics[0].distinct_count, + Precision::Exact(10) + ); + + // Test 2: Multiple group-by columns — product of distinct counts + // GROUP BY a, b with distinct_count(a)=10, distinct_count(b)=5 + // => output should be Inexact(min(10*5, 1000)) = Inexact(50) + let group_by_ab = PhysicalGroupBy::new_single(vec![ + (col("a", &schema)?, "a".to_string()), + (col("b", &schema)?, "b".to_string()), + ]); + + let agg_ab = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by_ab, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(c)") + .build()?, + )], + vec![None], + Arc::clone(&input), + Arc::clone(&schema), + )?); + + let stats_ab = agg_ab.partition_statistics(None)?; + assert_eq!(stats_ab.num_rows, Precision::Inexact(50)); + + // Test 3: Product exceeds input rows — capped at input rows + // 30 input rows, GROUP BY a, b with distinct_count(a)=10, distinct_count(b)=5 + // => product = 50, but capped at 30 + let input_small = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Exact(30), + total_byte_size: Precision::Exact(240), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics::new_unknown(), + ], + }, + (*schema).clone(), + )) as Arc; + + let group_by_ab2 = PhysicalGroupBy::new_single(vec![ + (col("a", &schema)?, "a".to_string()), + (col("b", &schema)?, "b".to_string()), + ]); + + let agg_small = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by_ab2, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("c", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(c)") + .build()?, + )], + vec![None], + Arc::clone(&input_small), + Arc::clone(&schema), + )?); + + let stats_small = agg_small.partition_statistics(None)?; + assert_eq!(stats_small.num_rows, Precision::Inexact(30)); + + // Test 4: One group-by column missing distinct_count — fallback to input rows + // GROUP BY a, c where c has no distinct_count + let group_by_ac = PhysicalGroupBy::new_single(vec![ + (col("a", &schema)?, "a".to_string()), + (col("c", &schema)?, "c".to_string()), + ]); + + let agg_ac = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by_ac, + vec![], + vec![], + Arc::clone(&input), + Arc::clone(&schema), + )?); + + let stats_ac = agg_ac.partition_statistics(None)?; + // Falls back to input row count since column c has no distinct_count + assert_eq!(stats_ac.num_rows, Precision::Inexact(1000)); + + // Test 5: Inexact distinct_count is also used + let input_inexact = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Exact(8000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Inexact(20), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics::new_unknown(), + ColumnStatistics::new_unknown(), + ], + }, + (*schema).clone(), + )) as Arc; + + let group_by_a2 = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); + + let agg_inexact = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by_a2, + vec![], + vec![], + input_inexact, + Arc::clone(&schema), + )?); + + let stats_inexact = agg_inexact.partition_statistics(None)?; + assert_eq!(stats_inexact.num_rows, Precision::Inexact(20)); + + Ok(()) + } + #[tokio::test] async fn test_order_is_retained_when_spilling() -> Result<()> { let schema = Arc::new(Schema::new(vec![