6 changes: 5 additions & 1 deletion datafusion/optimizer/src/eliminate_limit.rs
@@ -210,7 +210,11 @@ mod tests {
Sort: test.a ASC NULLS LAST, fetch=3
Limit: skip=0, fetch=2
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
TableScan: test
RightSemi Join: test.a = test.a
Limit: skip=0, fetch=2
Aggregate: groupBy=[[test.a]], aggr=[[]]
TableScan: test
TableScan: test
"
)
}
179 changes: 165 additions & 14 deletions datafusion/optimizer/src/push_down_limit.rs
@@ -18,16 +18,17 @@
//! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan

use std::cmp::min;
use std::collections::HashSet;
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::Result;
use datafusion_common::tree_node::Transformed;
use datafusion_common::utils::combine_limit;
use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan};
use datafusion_expr::{FetchType, SkipType, lit};
use datafusion_common::{NullEquality, Result, get_required_group_by_exprs_indices};
use datafusion_expr::logical_plan::{Aggregate, Join, JoinType, Limit, LogicalPlan};
use datafusion_expr::{Expr, FetchType, LogicalPlanBuilder, SkipType, lit};

/// Optimization rule that tries to push down `LIMIT`.
/// It will push down through projection, limits (taking the smaller limit)
@@ -47,7 +48,6 @@ impl OptimizerRule for PushDownLimit {
true
}

#[expect(clippy::only_used_in_recursion)]
fn rewrite(
&self,
plan: LogicalPlan,
@@ -123,6 +123,21 @@ impl OptimizerRule for PushDownLimit {
make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join)))
})),

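// A `LIMIT` cannot simply be pushed below an aggregate; instead, when the
// distinct-aggregation soft limit is enabled, preselect at most
// `skip + fetch` group keys and semi join them back against the input
// (see `prefilter_limited_aggregate`).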
LogicalPlan::Aggregate(aggregate)
if config
.options()
.optimizer
.enable_distinct_aggregation_soft_limit =>
{
if let Some(aggregate) =
prefilter_limited_aggregate(aggregate.clone(), fetch + skip)?
{
transformed_limit(skip, fetch, aggregate)
} else {
original_limit(skip, fetch, LogicalPlan::Aggregate(aggregate))
}
}

LogicalPlan::Sort(mut sort) => {
let new_fetch = {
let sort_fetch = skip + fetch;
@@ -237,6 +252,99 @@ fn transformed_limit(
Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input))))
}

/// Rewrite `LIMIT K (GROUP BY keys, aggs)` into a key preselection followed
/// by a semi join. The aggregate itself is left unchanged, while the join's
/// dynamic filter pushes the selected key set down into the second scan of
/// the input.
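///
/// A rough sketch of the rewritten plan shape (mirroring the tests below),
/// for `LIMIT k` over `GROUP BY a`:
///
/// ```text
/// Aggregate: groupBy=[[a]], aggr=[[...]]
///   RightSemi Join: a = a
///     Limit: skip=0, fetch=k
///       Aggregate: groupBy=[[a]], aggr=[[]]
///         TableScan: input
///     TableScan: input
/// ```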
fn prefilter_limited_aggregate(
aggregate: Aggregate,
limit: usize,
) -> Result<Option<LogicalPlan>> {
if limit == 0 || aggregate.aggr_expr.is_empty() || aggregate.group_expr.is_empty() {
return Ok(None);
}
if is_key_prefiltered_aggregate(&aggregate) {
return Ok(None);
}
if has_functionally_reducible_group_exprs(&aggregate) {
return Ok(None);
}

let mut seen_columns = HashSet::with_capacity(aggregate.group_expr.len());
let mut join_columns = Vec::with_capacity(aggregate.group_expr.len());
for expr in &aggregate.group_expr {
let Expr::Column(column) = expr else {
return Ok(None);
};
if !seen_columns.insert(column.clone()) {
return Ok(None);
}
join_columns.push(column.clone());
}

let key_input = aggregate.input.as_ref().clone();
let keys = LogicalPlanBuilder::from(key_input)
.aggregate(aggregate.group_expr.clone(), Vec::<Expr>::new())?
.limit(0, Some(limit))?
.build()?;

let filtered_input = LogicalPlanBuilder::from(keys)
.join_detailed(
aggregate.input.as_ref().clone(),
JoinType::RightSemi,
(join_columns.clone(), join_columns),
None,
NullEquality::NullEqualsNull,
)?
.build()?;

Aggregate::try_new(
Arc::new(filtered_input),
aggregate.group_expr,
aggregate.aggr_expr,
)
.map(LogicalPlan::Aggregate)
.map(Some)
}

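/// Returns true when the input's functional dependencies (e.g. a primary
/// key) make some of the group-by expressions redundant. Prefiltering is
/// skipped for such aggregates so projection optimization can remove the
/// redundant keys first instead of baking them into semi join predicates.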
fn has_functionally_reducible_group_exprs(aggregate: &Aggregate) -> bool {
if aggregate
.input
.schema()
.functional_dependencies()
.is_empty()
{
return false;
}

let group_expr_names = aggregate
.group_expr
.iter()
.map(|expr| expr.schema_name().to_string())
.collect::<Vec<_>>();

get_required_group_by_exprs_indices(aggregate.input.schema(), &group_expr_names)
.is_some_and(|required_indices| {
required_indices.len() < aggregate.group_expr.len()
})
}

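/// Detects the plan shape produced by [`prefilter_limited_aggregate`]: an
/// aggregate over a `RightSemi` join whose left side is a limited, key-only
/// aggregate with the same group keys. Skipping this shape keeps the rule
/// from firing again on its own output.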
fn is_key_prefiltered_aggregate(aggregate: &Aggregate) -> bool {
let LogicalPlan::Join(join) = aggregate.input.as_ref() else {
return false;
};
if join.join_type != JoinType::RightSemi {
return false;
}
let LogicalPlan::Limit(limit) = join.left.as_ref() else {
return false;
};
let LogicalPlan::Aggregate(keys) = limit.input.as_ref() else {
return false;
};

keys.aggr_expr.is_empty() && keys.group_expr == aggregate.group_expr
}

/// Adds a limit to the inputs of a join, if possible
fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {
use JoinType::*;
@@ -279,10 +387,11 @@ mod test {
use crate::test::*;

use crate::OptimizerContext;
use datafusion_common::DFSchemaRef;
use arrow::datatypes::Schema;
use datafusion_common::{Constraint, Constraints, DFSchemaRef};
use datafusion_expr::{
Expr, Extension, UserDefinedLogicalNodeCore, col, exists,
logical_plan::builder::LogicalPlanBuilder,
logical_plan::builder::{LogicalPlanBuilder, table_source_with_constraints},
};
use datafusion_functions_aggregate::expr_fn::max;

@@ -583,20 +692,52 @@
}

#[test]
fn limit_doesnt_push_down_aggregation() -> Result<()> {
fn limit_prefilters_aggregation() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![max(col("b"))])?
.limit(0, Some(1000))?
.build()?;

// Limit should *not* push down aggregate node
// Limit preselects group keys before running the aggregate
assert_optimized_plan_equal!(
plan,
@r"
Limit: skip=0, fetch=1000
Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
RightSemi Join: test.a = test.a
Limit: skip=0, fetch=1000
Aggregate: groupBy=[[test.a]], aggr=[[]]
TableScan: test
TableScan: test
"
)
}

#[test]
fn limit_does_not_prefilter_fd_reducible_aggregation() -> Result<()> {
let constraints =
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
let table_source = table_source_with_constraints(
&Schema::new(test_table_scan_fields()),
constraints,
);
let table_scan = LogicalPlanBuilder::scan("test", table_source, None)?.build()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a"), col("b"), col("c")], vec![max(col("b"))])?
.limit(0, Some(1000))?
.build()?;

// SQL planning may add functionally dependent fields as implicit group
// keys. Do not turn those redundant keys into semijoin predicates before
// projection optimization has a chance to simplify them.
assert_optimized_plan_equal!(
plan,
@r"
Limit: skip=0, fetch=1000
Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[max(test.b)]]
TableScan: test
"
)
@@ -675,14 +816,20 @@ mod test {
.limit(0, Some(10))?
.build()?;

// Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation
// Limit should use deeper LIMIT 1000 and preselect group keys for the
// aggregate using the outer LIMIT 10.
assert_optimized_plan_equal!(
plan,
@r"
Limit: skip=0, fetch=10
Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
Limit: skip=0, fetch=1000
TableScan: test, fetch=1000
RightSemi Join: test.a = test.a
Limit: skip=0, fetch=10
Aggregate: groupBy=[[test.a]], aggr=[[]]
Limit: skip=0, fetch=1000
TableScan: test, fetch=1000
Limit: skip=0, fetch=1000
TableScan: test, fetch=1000
"
)
}
@@ -786,21 +933,25 @@ mod test {
}

#[test]
fn limit_doesnt_push_down_with_offset_aggregation() -> Result<()> {
fn limit_with_offset_prefilters_aggregation() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![max(col("b"))])?
.limit(10, Some(1000))?
.build()?;

// Limit should *not* push down aggregate node
// Limit preselects enough group keys to satisfy offset and fetch
assert_optimized_plan_equal!(
plan,
@r"
Limit: skip=10, fetch=1000
Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]
TableScan: test
RightSemi Join: test.a = test.a
Limit: skip=0, fetch=1010
Aggregate: groupBy=[[test.a]], aggr=[[]]
TableScan: test
TableScan: test
"
)
}
19 changes: 15 additions & 4 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -6607,23 +6607,34 @@ SELECT DISTINCT c3 FROM aggregate_test_100 WHERE c3 between 10 and 20 group by c
14
17

# An aggregate expression causes the limit to not be pushed to the aggregation
# With an aggregate expression, the limit prefilters the input through a limited group-key set
query TT
EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5;
----
logical_plan
01)Projection: max(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3
02)--Limit: skip=0, fetch=5
03)----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[max(aggregate_test_100.c1)]]
04)------TableScan: aggregate_test_100 projection=[c1, c2, c3]
04)------RightSemi Join: aggregate_test_100.c2 = aggregate_test_100.c2, aggregate_test_100.c3 = aggregate_test_100.c3
05)--------Limit: skip=0, fetch=5
06)----------Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[]]
07)------------TableScan: aggregate_test_100 projection=[c2, c3]
08)--------TableScan: aggregate_test_100 projection=[c1, c2, c3]
physical_plan
01)ProjectionExec: expr=[max(aggregate_test_100.c1)@2 as max(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3]
02)--GlobalLimitExec: skip=0, fetch=5
03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[max(aggregate_test_100.c1)]
04)------CoalescePartitionsExec
05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[max(aggregate_test_100.c1)]
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true
06)----------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(c2@0, c2@1), (c3@1, c3@2)], NullsEqual: true
07)------------GlobalLimitExec: skip=0, fetch=5
08)--------------AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5]
09)----------------CoalescePartitionsExec
10)------------------AggregateExec: mode=Partial, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[5]
11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c2, c3], file_type=csv, has_header=true
13)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
14)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100_with_dates.csv]]}, projection=[c1, c2, c3], file_type=csv, has_header=true

# TODO(msirek): Extend checking in LimitedDistinctAggregation equal groupings to ignore the order of columns
# in the group-by column lists, so the limit could be pushed to the lowest AggregateExec in this case
20 changes: 16 additions & 4 deletions datafusion/sqllogictest/test_files/clickbench.slt
@@ -490,15 +490,27 @@
01)Projection: hits.UserID, hits.SearchPhrase, count(Int64(1)) AS count(*)
02)--Limit: skip=0, fetch=10
03)----Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[count(Int64(1))]]
04)------SubqueryAlias: hits
05)--------TableScan: hits_raw projection=[UserID, SearchPhrase]
04)------RightSemi Join: hits.UserID = hits.UserID, hits.SearchPhrase = hits.SearchPhrase
05)--------Limit: skip=0, fetch=10
06)----------Aggregate: groupBy=[[hits.UserID, hits.SearchPhrase]], aggr=[[]]
07)------------SubqueryAlias: hits
08)--------------TableScan: hits_raw projection=[UserID, SearchPhrase]
09)--------SubqueryAlias: hits
10)----------TableScan: hits_raw projection=[UserID, SearchPhrase]
physical_plan
01)ProjectionExec: expr=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase, count(Int64(1))@2 as count(*)]
02)--CoalescePartitionsExec: fetch=10
03)----AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))]
04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1
04)------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=4
05)--------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[count(Int64(1))]
06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
07)------------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(UserID@0, UserID@0), (SearchPhrase@1, SearchPhrase@1)], NullsEqual: true
08)--------------CoalescePartitionsExec: fetch=10
09)----------------AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10]
10)------------------RepartitionExec: partitioning=Hash([UserID@0, SearchPhrase@1], 4), input_partitions=1
11)--------------------AggregateExec: mode=Partial, gby=[UserID@0 as UserID, SearchPhrase@1 as SearchPhrase], aggr=[], lim=[10]
12)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet
13)--------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, projection=[UserID, SearchPhrase], file_type=parquet, predicate=DynamicFilter [ empty ]

query ITI rowsort
SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPhrase" LIMIT 10;