Skip to content

Commit c048aa0

Browse files
committed
Add rewrite SUM(expr+C) --> SUM(expr) + COUNT(expr)*C
1 parent a51ba01 commit c048aa0

6 files changed

Lines changed: 331 additions & 51 deletions

File tree

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ impl Alias {
600600
}
601601
}
602602

603-
/// Binary expression
603+
/// Binary expression for [`Expr::BinaryExpr`]
604604
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
605605
pub struct BinaryExpr {
606606
/// Left-hand side of the expression

datafusion/expr/src/simplify.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub struct SimplifyContext {
3838
schema: DFSchemaRef,
3939
query_execution_start_time: Option<DateTime<Utc>>,
4040
config_options: Arc<ConfigOptions>,
41+
aggregate_exprs: Option<Arc<Vec<Expr>>>,
4142
}
4243

4344
impl Default for SimplifyContext {
@@ -46,6 +47,7 @@ impl Default for SimplifyContext {
4647
schema: Arc::new(DFSchema::empty()),
4748
query_execution_start_time: None,
4849
config_options: Arc::new(ConfigOptions::default()),
50+
aggregate_exprs: None,
4951
}
5052
}
5153
}
@@ -78,6 +80,12 @@ impl SimplifyContext {
7880
self
7981
}
8082

83+
/// Set aggregate expressions from the containing aggregate node, if any.
84+
pub fn with_aggregate_exprs(mut self, aggregate_exprs: Arc<Vec<Expr>>) -> Self {
85+
self.aggregate_exprs = Some(aggregate_exprs);
86+
self
87+
}
88+
8189
/// Returns the schema
8290
pub fn schema(&self) -> &DFSchemaRef {
8391
&self.schema
@@ -108,6 +116,11 @@ impl SimplifyContext {
108116
pub fn config_options(&self) -> &Arc<ConfigOptions> {
109117
&self.config_options
110118
}
119+
120+
/// Returns aggregate expressions from the containing aggregate node, if any.
121+
pub fn aggregate_exprs(&self) -> Option<&[Expr]> {
122+
self.aggregate_exprs.as_deref().map(Vec::as_slice)
123+
}
111124
}
112125

113126
/// Was the expression simplified?

datafusion/functions-aggregate/src/sum.rs

Lines changed: 139 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,16 @@ use datafusion_common::types::{
3232
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
3333
};
3434
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
35-
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
35+
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
36+
use datafusion_expr::function::{
37+
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
38+
};
39+
use datafusion_expr::simplify::SimplifyContext;
3640
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
3741
use datafusion_expr::{
38-
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
39-
ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
40-
Volatility,
42+
Accumulator, AggregateUDFImpl, BinaryExpr, Coercion, Documentation, Expr,
43+
GroupsAccumulator, Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
44+
TypeSignatureClass, Volatility,
4145
};
4246
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
4347
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
@@ -54,7 +58,7 @@ make_udaf_expr_and_func!(
5458
);
5559

5660
pub fn sum_distinct(expr: Expr) -> Expr {
57-
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
61+
Expr::AggregateFunction(AggregateFunction::new_udf(
5862
sum_udaf(),
5963
vec![expr],
6064
true,
@@ -346,6 +350,136 @@ impl AggregateUDFImpl for Sum {
346350
_ => SetMonotonicity::NotMonotonic,
347351
}
348352
}
353+
354+
/// Simplification Rules
355+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
356+
Some(Box::new(sum_simplifier))
357+
}
358+
}
359+
360+
/// Implement ClickBench Q29 specific optimization:
361+
/// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
362+
///
363+
/// Backstory: TODO
364+
///
365+
fn sum_simplifier(mut agg: AggregateFunction, info: &SimplifyContext) -> Result<Expr> {
366+
// Explicitly destructure to ensure we check all relevant fields
367+
let AggregateFunctionParams {
368+
args,
369+
distinct,
370+
filter,
371+
order_by,
372+
null_treatment,
373+
} = &agg.params;
374+
375+
if *distinct
376+
|| filter.is_some()
377+
|| !order_by.is_empty()
378+
|| null_treatment.is_some()
379+
|| args.len() != 1
380+
{
381+
return Ok(Expr::AggregateFunction(agg));
382+
}
383+
384+
// otherwise check the arguments if they are <arg> + <literal>
385+
let (arg, lit) = match SplitResult::new(agg.params.args[0].clone()) {
386+
SplitResult::Original => return Ok(Expr::AggregateFunction(agg)),
387+
SplitResult::Split { arg, lit } => (arg, lit),
388+
};
389+
390+
if !has_common_rewrite_arg(&arg, info) {
391+
return Ok(Expr::AggregateFunction(agg));
392+
}
393+
394+
// Rewrite to SUM(arg)
395+
agg.params.args = vec![arg.clone()];
396+
let sum_agg = Expr::AggregateFunction(agg);
397+
398+
// sum(arg) + scalar * COUNT(arg)
399+
Ok(sum_agg + (lit * crate::count::count(arg)))
400+
}
401+
402+
fn has_common_rewrite_arg(arg: &Expr, info: &SimplifyContext) -> bool {
403+
let Some(aggregate_exprs) = info.aggregate_exprs() else {
404+
// Only apply this rewrite in the context of an Aggregate node where
405+
// sibling aggregate expressions are known.
406+
return false;
407+
};
408+
409+
aggregate_exprs
410+
.iter()
411+
.filter_map(sum_rewrite_candidate_arg)
412+
.filter(|candidate_arg| candidate_arg == arg)
413+
.take(2)
414+
.count()
415+
> 1
416+
}
417+
418+
fn sum_rewrite_candidate_arg(expr: &Expr) -> Option<Expr> {
419+
let Expr::AggregateFunction(aggregate_fn) = expr.clone().unalias_nested().data else {
420+
return None;
421+
};
422+
if !aggregate_fn.func.name().eq_ignore_ascii_case("sum") {
423+
return None;
424+
}
425+
426+
let AggregateFunctionParams {
427+
args,
428+
distinct,
429+
filter,
430+
order_by,
431+
null_treatment,
432+
} = &aggregate_fn.params;
433+
434+
if *distinct
435+
|| filter.is_some()
436+
|| !order_by.is_empty()
437+
|| null_treatment.is_some()
438+
|| args.len() != 1
439+
{
440+
return None;
441+
}
442+
443+
match SplitResult::new(args[0].clone()) {
444+
SplitResult::Split { arg, .. } => Some(arg),
445+
SplitResult::Original => None,
446+
}
447+
}
448+
449+
/// Result of trying to split an expression into an arg and constant
450+
#[derive(Debug, Clone)]
451+
enum SplitResult {
452+
/// if the expression is either of
453+
/// * `<arg> <op> <lit>`
454+
/// * `<lit> <op> <arg>`
455+
///
456+
/// When `op` is `+`
457+
Split { arg: Expr, lit: Expr },
458+
/// If the expression is something else
459+
Original,
460+
}
461+
462+
impl SplitResult {
463+
fn new(expr: Expr) -> Self {
464+
let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
465+
return Self::Original;
466+
};
467+
if op != Operator::Plus {
468+
return Self::Original;
469+
}
470+
471+
match (left.as_ref(), right.as_ref()) {
472+
(Expr::Literal(..), _) => Self::Split {
473+
arg: *right,
474+
lit: *left,
475+
},
476+
(_, Expr::Literal(..)) => Self::Split {
477+
arg: *left,
478+
lit: *right,
479+
},
480+
_ => Self::Original,
481+
}
482+
}
349483
}
350484

351485
/// This accumulator computes SUM incrementally

0 commit comments

Comments
 (0)