diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java index 832f9c25e776f2..432d25586a0450 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownProject.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -50,6 +51,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.Function; /** push down project if the expression instance of PreferPushDownProject */ @@ -60,6 +62,9 @@ public List buildRules() { RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build( logicalJoin().thenApply(this::pushDownJoinExpressions) ), + RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build( + logicalFilter(logicalJoin()).thenApply(this::pushDownFilterExpressions) + ), RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build( logicalProject(logicalJoin()).thenApply(this::defaultPushDownProject) ), @@ -137,6 +142,20 @@ private Plan pushDownJoinExpressions(MatchingContext> ct ).withChildren(newLeft, newRight); } + private Plan pushDownFilterExpressions(MatchingContext>> ctx) { + LogicalFilter> filter = ctx.root; + LogicalJoin join = filter.child(); + PushdownProjectHelper pushdownProjectHelper = new PushdownProjectHelper(ctx.statementContext, join); + Pair> pushPredicates + = pushdownProjectHelper.pushDownExpressions(filter.getConjuncts()); + if (!pushPredicates.first) { + return filter; + } + + LogicalJoin newJoin = join.withChildren(pushdownProjectHelper.buildNewChildren()); + return filter.withConjuncts(pushPredicates.second).withChildren(ImmutableList.of(newJoin)); + } + // return: // key: rewrite the PreferPushDownProject to slot // value: the pushed down project outputs which contains the Alias(PreferPushDownProject) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java index 47398e3ef9a458..0a0b425e81766a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownProjectTest.java @@ -18,23 +18,37 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.MatchAny; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.PreferPushDownProject; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; +import org.apache.doris.qe.ConnectContext; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import org.junit.jupiter.api.Test; @@ -82,6 +96,7 @@ public class PushDownProjectTest implements MemoPatternMatchSupported { private final LogicalOneRowRelation rel1 = new LogicalOneRowRelation(new RelationId(1), rel1Output); private final LogicalOneRowRelation rel2 = new LogicalOneRowRelation(new RelationId(2), rel2Output); private final List children = Lists.newArrayList(rel1, rel2); + private final ConnectContext connectContext = MemoTestUtils.createConnectContext(); @Test public void testPushDownProjectThroughUnionOnlyHasChildren() { @@ -189,4 +204,40 @@ public void testPushDownProjectThroughUnionHasNoChildren() { ).when(p -> p.getProjects().stream().noneMatch(ne -> ne.containsType(ElementAt.class))) ); } + + @Test + public void shouldRewritePreferPushDownProjectInOrFilterToSlot() { + LogicalPlan rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, + ImmutableList.of("")); + LogicalPlan rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, + ImmutableList.of("")); + Expression preferPushDownProjectExpr = new MatchAny( + new Add(rStudent.getOutput().get(0), Literal.of(1)), + Literal.of("abc")); + Expression rightSidePredicate = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); + Expression orPredicate = new Or(preferPushDownProjectExpr, rightSidePredicate); + + LogicalPlan plan = new LogicalPlanBuilder(rStudent) + .joinEmptyOn(rScore, JoinType.INNER_JOIN) + .filter(orPredicate) + .build(); + + PlanChecker.from(connectContext, plan) + .applyTopDown(new PushDownProject()) + .matchesFromRoot(logicalFilter( + logicalJoin( + logicalProject(logicalOlapScan()) + .when(project -> project.getProjects().stream() + .filter(Alias.class::isInstance) + .map(Alias.class::cast) + .map(Alias::child) + .anyMatch(PreferPushDownProject.class::isInstance)), + logicalOlapScan())) + .when(filter -> { + Expression rewrittenPredicate = ImmutableList.copyOf(filter.getConjuncts()).get(0); + return rewrittenPredicate instanceof Or + && rewrittenPredicate.anyMatch(SlotReference.class::isInstance) + && !rewrittenPredicate.anyMatch(PreferPushDownProject.class::isInstance); + })); + } }