Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
Expand All @@ -50,6 +51,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/** push down project if the expression instance of PreferPushDownProject */
Expand All @@ -60,6 +62,9 @@ public List<Rule> buildRules() {
RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build(
logicalJoin().thenApply(this::pushDownJoinExpressions)
),
RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build(
logicalFilter(logicalJoin()).thenApply(this::pushDownFilterExpressions)
),
RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build(
logicalProject(logicalJoin()).thenApply(this::defaultPushDownProject)
),
Expand Down Expand Up @@ -137,6 +142,20 @@ private Plan pushDownJoinExpressions(MatchingContext<LogicalJoin<Plan, Plan>> ct
).withChildren(newLeft, newRight);
}

private Plan pushDownFilterExpressions(MatchingContext<LogicalFilter<LogicalJoin<Plan, Plan>>> ctx) {
LogicalFilter<LogicalJoin<Plan, Plan>> filter = ctx.root;
LogicalJoin<Plan, Plan> join = filter.child();
PushdownProjectHelper pushdownProjectHelper = new PushdownProjectHelper(ctx.statementContext, join);
Pair<Boolean, Set<Expression>> pushPredicates
= pushdownProjectHelper.pushDownExpressions(filter.getConjuncts());
if (!pushPredicates.first) {
return filter;
}

LogicalJoin<Plan, Plan> newJoin = join.withChildren(pushdownProjectHelper.buildNewChildren());
return filter.withConjuncts(pushPredicates.second).withChildren(ImmutableList.of(newJoin));
}

// return:
// key: rewrite the PreferPushDownProject to slot
// value: the pushed down project outputs which contains the Alias(PreferPushDownProject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,37 @@
package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.MatchAny;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.PreferPushDownProject;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -82,6 +96,7 @@ public class PushDownProjectTest implements MemoPatternMatchSupported {
private final LogicalOneRowRelation rel1 = new LogicalOneRowRelation(new RelationId(1), rel1Output);
private final LogicalOneRowRelation rel2 = new LogicalOneRowRelation(new RelationId(2), rel2Output);
private final List<Plan> children = Lists.newArrayList(rel1, rel2);
private final ConnectContext connectContext = MemoTestUtils.createConnectContext();

@Test
public void testPushDownProjectThroughUnionOnlyHasChildren() {
Expand Down Expand Up @@ -189,4 +204,40 @@ public void testPushDownProjectThroughUnionHasNoChildren() {
).when(p -> p.getProjects().stream().noneMatch(ne -> ne.containsType(ElementAt.class)))
);
}

@Test
public void shouldRewritePreferPushDownProjectInOrFilterToSlot() {
LogicalPlan rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student,
ImmutableList.of(""));
LogicalPlan rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score,
ImmutableList.of(""));
Expression preferPushDownProjectExpr = new MatchAny(
new Add(rStudent.getOutput().get(0), Literal.of(1)),
Literal.of("abc"));
Expression rightSidePredicate = new GreaterThan(rScore.getOutput().get(2), Literal.of(60));
Expression orPredicate = new Or(preferPushDownProjectExpr, rightSidePredicate);

LogicalPlan plan = new LogicalPlanBuilder(rStudent)
.joinEmptyOn(rScore, JoinType.INNER_JOIN)
.filter(orPredicate)
.build();

PlanChecker.from(connectContext, plan)
.applyTopDown(new PushDownProject())
.matchesFromRoot(logicalFilter(
logicalJoin(
logicalProject(logicalOlapScan())
.when(project -> project.getProjects().stream()
.filter(Alias.class::isInstance)
.map(Alias.class::cast)
.map(Alias::child)
.anyMatch(PreferPushDownProject.class::isInstance)),
logicalOlapScan()))
.when(filter -> {
Expression rewrittenPredicate = ImmutableList.copyOf(filter.getConjuncts()).get(0);
return rewrittenPredicate instanceof Or
&& rewrittenPredicate.anyMatch(SlotReference.class::isInstance)
&& !rewrittenPredicate.anyMatch(PreferPushDownProject.class::isInstance);
}));
}
}
Loading