@@ -32,12 +32,16 @@ use datafusion_common::types::{
3232 logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333} ;
3434use datafusion_common:: { HashMap , Result , ScalarValue , exec_err, not_impl_err} ;
35- use datafusion_expr:: function:: { AccumulatorArgs , StateFieldsArgs } ;
35+ use datafusion_expr:: expr:: { AggregateFunction , AggregateFunctionParams } ;
36+ use datafusion_expr:: function:: {
37+ AccumulatorArgs , AggregateFunctionSimplification , StateFieldsArgs ,
38+ } ;
39+ use datafusion_expr:: simplify:: SimplifyContext ;
3640use datafusion_expr:: utils:: { AggregateOrderSensitivity , format_state_name} ;
3741use datafusion_expr:: {
38- Accumulator , AggregateUDFImpl , Coercion , Documentation , Expr , GroupsAccumulator ,
39- ReversedUDAF , SetMonotonicity , Signature , TypeSignature , TypeSignatureClass ,
40- Volatility ,
42+ Accumulator , AggregateUDFImpl , BinaryExpr , Coercion , Documentation , Expr ,
43+ GroupsAccumulator , Operator , ReversedUDAF , SetMonotonicity , Signature , TypeSignature ,
44+ TypeSignatureClass , Volatility ,
4145} ;
4246use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: prim_op:: PrimitiveGroupsAccumulator ;
4347use datafusion_functions_aggregate_common:: aggregate:: sum_distinct:: DistinctSumAccumulator ;
@@ -54,7 +58,7 @@ make_udaf_expr_and_func!(
5458) ;
5559
5660pub fn sum_distinct ( expr : Expr ) -> Expr {
57- Expr :: AggregateFunction ( datafusion_expr :: expr :: AggregateFunction :: new_udf (
61+ Expr :: AggregateFunction ( AggregateFunction :: new_udf (
5862 sum_udaf ( ) ,
5963 vec ! [ expr] ,
6064 true ,
@@ -346,6 +350,136 @@ impl AggregateUDFImpl for Sum {
346350 _ => SetMonotonicity :: NotMonotonic ,
347351 }
348352 }
353+
354+ /// Simplification Rules
355+ fn simplify ( & self ) -> Option < AggregateFunctionSimplification > {
356+ Some ( Box :: new ( sum_simplifier) )
357+ }
358+ }
359+
360+ /// Implement ClickBench Q29 specific optimization:
361+ /// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
362+ ///
363+ /// Backstory: TODO
364+ ///
365+ fn sum_simplifier ( mut agg : AggregateFunction , info : & SimplifyContext ) -> Result < Expr > {
366+ // Explicitly destructure to ensure we check all relevant fields
367+ let AggregateFunctionParams {
368+ args,
369+ distinct,
370+ filter,
371+ order_by,
372+ null_treatment,
373+ } = & agg. params ;
374+
375+ if * distinct
376+ || filter. is_some ( )
377+ || !order_by. is_empty ( )
378+ || null_treatment. is_some ( )
379+ || args. len ( ) != 1
380+ {
381+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
382+ }
383+
384+ // otherwise check the arguments if they are <arg> + <literal>
385+ let ( arg, lit) = match SplitResult :: new ( agg. params . args [ 0 ] . clone ( ) ) {
386+ SplitResult :: Original => return Ok ( Expr :: AggregateFunction ( agg) ) ,
387+ SplitResult :: Split { arg, lit } => ( arg, lit) ,
388+ } ;
389+
390+ if !has_common_rewrite_arg ( & arg, info) {
391+ return Ok ( Expr :: AggregateFunction ( agg) ) ;
392+ }
393+
394+ // Rewrite to SUM(arg)
395+ agg. params . args = vec ! [ arg. clone( ) ] ;
396+ let sum_agg = Expr :: AggregateFunction ( agg) ;
397+
398+ // sum(arg) + scalar * COUNT(arg)
399+ Ok ( sum_agg + ( lit * crate :: count:: count ( arg) ) )
400+ }
401+
402+ fn has_common_rewrite_arg ( arg : & Expr , info : & SimplifyContext ) -> bool {
403+ let Some ( aggregate_exprs) = info. aggregate_exprs ( ) else {
404+ // Only apply this rewrite in the context of an Aggregate node where
405+ // sibling aggregate expressions are known.
406+ return false ;
407+ } ;
408+
409+ aggregate_exprs
410+ . iter ( )
411+ . filter_map ( sum_rewrite_candidate_arg)
412+ . filter ( |candidate_arg| candidate_arg == arg)
413+ . take ( 2 )
414+ . count ( )
415+ > 1
416+ }
417+
418+ fn sum_rewrite_candidate_arg ( expr : & Expr ) -> Option < Expr > {
419+ let Expr :: AggregateFunction ( aggregate_fn) = expr. clone ( ) . unalias_nested ( ) . data else {
420+ return None ;
421+ } ;
422+ if !aggregate_fn. func . name ( ) . eq_ignore_ascii_case ( "sum" ) {
423+ return None ;
424+ }
425+
426+ let AggregateFunctionParams {
427+ args,
428+ distinct,
429+ filter,
430+ order_by,
431+ null_treatment,
432+ } = & aggregate_fn. params ;
433+
434+ if * distinct
435+ || filter. is_some ( )
436+ || !order_by. is_empty ( )
437+ || null_treatment. is_some ( )
438+ || args. len ( ) != 1
439+ {
440+ return None ;
441+ }
442+
443+ match SplitResult :: new ( args[ 0 ] . clone ( ) ) {
444+ SplitResult :: Split { arg, .. } => Some ( arg) ,
445+ SplitResult :: Original => None ,
446+ }
447+ }
448+
449+ /// Result of trying to split an expression into an arg and constant
450+ #[ derive( Debug , Clone ) ]
451+ enum SplitResult {
452+ /// if the expression is either of
453+ /// * `<arg> <op> <lit>`
454+ /// * `<lit> <op> <arg>`
455+ ///
456+ /// When `op` is `+`
457+ Split { arg : Expr , lit : Expr } ,
458+ /// If the expression is something else
459+ Original ,
460+ }
461+
462+ impl SplitResult {
463+ fn new ( expr : Expr ) -> Self {
464+ let Expr :: BinaryExpr ( BinaryExpr { left, op, right } ) = expr else {
465+ return Self :: Original ;
466+ } ;
467+ if op != Operator :: Plus {
468+ return Self :: Original ;
469+ }
470+
471+ match ( left. as_ref ( ) , right. as_ref ( ) ) {
472+ ( Expr :: Literal ( ..) , _) => Self :: Split {
473+ arg : * right,
474+ lit : * left,
475+ } ,
476+ ( _, Expr :: Literal ( ..) ) => Self :: Split {
477+ arg : * left,
478+ lit : * right,
479+ } ,
480+ _ => Self :: Original ,
481+ }
482+ }
349483}
350484
351485/// This accumulator computes SUM incrementally
0 commit comments