Skip to content

Commit 7654dd8

Browse files
committed
[CALCITE-7362] Add rule to transform WHERE clauses into filtered aggregates
1 parent 80e717a commit 7654dd8

4 files changed

Lines changed: 403 additions & 0 deletions

File tree

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.calcite.rel.rules;
18+
19+
import org.apache.calcite.plan.RelOptRuleCall;
20+
import org.apache.calcite.plan.RelRule;
21+
import org.apache.calcite.rel.core.Aggregate;
22+
import org.apache.calcite.rel.core.AggregateCall;
23+
import org.apache.calcite.rel.core.Filter;
24+
import org.apache.calcite.rex.RexNode;
25+
import org.apache.calcite.tools.RelBuilder;
26+
27+
import org.immutables.value.Value;
28+
29+
import java.util.ArrayList;
30+
import java.util.List;
31+
32+
/**
33+
* Rule that converts an aggregate on top of a filter into a filtered aggregate.
34+
*
35+
* <p>Before
36+
* <pre><code>
37+
* SELECT SUM(salary)
38+
* FROM Emp
39+
* WHERE deptno = 10
40+
* </code></pre>
41+
*
42+
* <p>After
43+
* <pre><code>
44+
* SELECT SUM(salary) FILTER (WHERE deptno = 10)
45+
* FROM Emp
46+
* </code></pre>
47+
*
48+
* <p>The transformation is particularly useful in view-based rewriting.
49+
* The removal of the {@code Filter} operators lifts some restrictions when using
50+
* the {@link org.apache.calcite.rel.rules.materialize.MaterializedViewRules}.
51+
*
52+
* <p>Filtered aggregates can be transformed to other equivalent forms via other
53+
* transformation rules (e.g., {@link AggregateFilterToCaseRule}).
54+
*/
55+
@Value.Enclosing public class AggregateFilterToFilteredAggregateRule
56+
extends RelRule<AggregateFilterToFilteredAggregateRule.Config> {
57+
58+
private AggregateFilterToFilteredAggregateRule(Config config) {
59+
super(config);
60+
}
61+
62+
@Override public void onMatch(RelOptRuleCall call) {
63+
Aggregate aggregate = call.rel(0);
64+
Filter filter = call.rel(1);
65+
if (!aggregate.getGroupSet().isEmpty()) {
66+
// At the moment we only support the transformation for grand totals, i.e.,
67+
// aggregates with no grouping keys.
68+
return;
69+
}
70+
RelBuilder builder = call.builder();
71+
builder.push(filter.getInput());
72+
List<RexNode> projects = new ArrayList<>(builder.fields());
73+
List<AggregateCall> newAggCalls = new ArrayList<>();
74+
for (AggregateCall aggCall : aggregate.getAggCallList()) {
75+
if (!aggCall.getAggregation().allowsFilter()) {
76+
return;
77+
}
78+
RexNode condition = filter.getCondition();
79+
// If the aggregate call has its own filter, combine it with the filter condition.
80+
if (aggCall.hasFilter()) {
81+
condition = builder.and(condition, builder.field(aggCall.filterArg));
82+
}
83+
int pos = projects.indexOf(condition);
84+
if (pos < 0) {
85+
pos = projects.size();
86+
projects.add(condition);
87+
}
88+
newAggCalls.add(aggCall.withFilter(pos));
89+
}
90+
builder.project(projects);
91+
builder.aggregate(builder.groupKey(), newAggCalls);
92+
call.transformTo(builder.build());
93+
}
94+
95+
/** Rule configuration. */
96+
@Value.Immutable public interface Config extends RelRule.Config {
97+
Config DEFAULT = ImmutableAggregateFilterToFilteredAggregateRule.Config.of()
98+
.withOperandSupplier(
99+
a -> a.operand(Aggregate.class).oneInput(f -> f.operand(Filter.class).anyInputs()));
100+
101+
@Override default AggregateFilterToFilteredAggregateRule toRule() {
102+
return new AggregateFilterToFilteredAggregateRule(this);
103+
}
104+
}
105+
}

core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,11 @@ private CoreRules() {}
959959
public static final AggregateFilterToCaseRule AGGREGATE_FILTER_TO_CASE =
960960
AggregateFilterToCaseRule.Config.DEFAULT.toRule();
961961

962+
/** Rule that converts an aggregate on of a filter into a filtered aggregate. */
963+
public static final AggregateFilterToFilteredAggregateRule
964+
AGGREGATE_FILTER_TO_FILTERED_AGGREGATE =
965+
AggregateFilterToFilteredAggregateRule.Config.DEFAULT.toRule();
966+
962967
/** Rule that remove duplicate {@link Sort} keys. */
963968
public static final SortRemoveDuplicateKeysRule SORT_REMOVE_DUPLICATE_KEYS =
964969
SortRemoveDuplicateKeysRule.Config.DEFAULT.toRule();
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.calcite.test;
18+
19+
import org.apache.calcite.plan.RelOptRule;
20+
import org.apache.calcite.plan.hep.HepProgram;
21+
import org.apache.calcite.rel.rules.AggregateFilterToFilteredAggregateRule;
22+
import org.apache.calcite.rel.rules.CoreRules;
23+
24+
import org.junit.jupiter.api.AfterAll;
25+
import org.junit.jupiter.api.Test;
26+
27+
import java.util.ArrayList;
28+
import java.util.List;
29+
30+
import static org.apache.calcite.rel.rules.CoreRules.AGGREGATE_FILTER_TO_FILTERED_AGGREGATE;
31+
import static org.apache.calcite.rel.rules.CoreRules.AGGREGATE_PROJECT_MERGE;
32+
import static org.apache.calcite.rel.rules.CoreRules.PROJECT_FILTER_TRANSPOSE_WHOLE_PROJECT_EXPRESSIONS;
33+
34+
/**
35+
* Unit tests for {@link AggregateFilterToFilteredAggregateRule}.
36+
*
37+
* <p>Relevant tickets:
38+
* <ul>
39+
* <li><a href="https://issues.apache.org/jira/browse/CALCITE-7362">
40+
* [CALCITE-7362] Add rule to transform WHERE clauses into filtered aggregates
41+
* </a></li>
42+
* </ul>
43+
*/
44+
class AggregateFilterToFilteredAggregateRuleTest {
45+
46+
private static RelOptFixture fixture() {
47+
return RelOptFixture.DEFAULT.withDiffRepos(
48+
DiffRepository.lookup(AggregateFilterToFilteredAggregateRuleTest.class));
49+
}
50+
51+
private static RelOptFixture sql(String sql) {
52+
return fixture().sql(sql);
53+
}
54+
55+
@Test void testSingleColumnAggregate() {
56+
String sql = "select sum(sal) from emp where deptno = 10";
57+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
58+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE).check();
59+
}
60+
61+
@Test void testSingleStarAggregate() {
62+
String sql = "select count(*) from emp where deptno = 10";
63+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
64+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE).check();
65+
}
66+
67+
@Test void testMultiAggregates() {
68+
String sql = "select sum(sal), min(sal), max(sal), count(*) from emp where deptno = 10";
69+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
70+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE).check();
71+
}
72+
73+
@Test void testSingleColumnFilteredAggregate() {
74+
String sql = "select sum(sal) filter (where ename = 'Bob') from emp where deptno = 10";
75+
List<RelOptRule> preRules = new ArrayList<>();
76+
preRules.add(AGGREGATE_PROJECT_MERGE);
77+
preRules.add(PROJECT_FILTER_TRANSPOSE_WHOLE_PROJECT_EXPRESSIONS);
78+
sql(sql).withPre(HepProgram.builder().addRuleCollection(preRules).build())
79+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE,
80+
CoreRules.PROJECT_MERGE).check();
81+
}
82+
83+
@Test void testAggregateNoSupportingFilter() {
84+
String sql = "select single_value(sal) from emp where deptno = 10";
85+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
86+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE)
87+
.checkUnchanged();
88+
}
89+
90+
@Test void testSingleColumnAggregateWithGroupBy() {
91+
String sql = "select sum(sal) from emp where deptno = 10 group by job";
92+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
93+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE)
94+
.checkUnchanged();
95+
}
96+
97+
@Test void testSingleColumnAggregateWithGroupingSets() {
98+
String sql =
99+
"select sum(sal) from emp where deptno = 10 group by grouping sets ((job), (ename))";
100+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
101+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE)
102+
.checkUnchanged();
103+
}
104+
105+
@Test void testSingleColumnAggregateWithEmptyGroupBy() {
106+
String sql = "select sum(sal) from emp where deptno = 10 group by ()";
107+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
108+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE).check();
109+
}
110+
111+
@Test void testSingleColumnAggregateWithEmptyGroupingSets() {
112+
String sql = "select sum(sal) from emp where deptno = 10 group by grouping sets (())";
113+
sql(sql).withPreRule(AGGREGATE_PROJECT_MERGE)
114+
.withRule(AGGREGATE_FILTER_TO_FILTERED_AGGREGATE).check();
115+
}
116+
117+
@AfterAll static void checkActualAndReferenceFiles() {
118+
fixture().diffRepos.checkActualAndReferenceFiles();
119+
}
120+
}

0 commit comments

Comments
 (0)