Skip to content

Commit

Permalink
[feature](Nereids): add ColumnPruningPostProcessor. (#32800)
Browse files Browse the repository at this point in the history
(cherry picked from commit 5970f98)
  • Loading branch information
jackwener committed Apr 25, 2024
1 parent 54902f3 commit e616e2a
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Prune column for Join-Cluster.
 *
 * <p>When a {@link PhysicalProject} sits directly on top of a physical join, the slots
 * referenced by the project and by the join's own conjuncts are collected. A join input
 * that outputs more slots than are actually referenced gets an extra
 * {@link PhysicalProject} that keeps only the referenced slots.
 */
@DependsRules({
    MergeProjectPostProcessor.class
})
public class ColumnPruningPostProcessor extends PlanPostProcessor {
    /**
     * Rewrites the subtree below {@code project} and, if the rewritten child is a join,
     * inserts pruning projections on the join inputs.
     *
     * @param project the project being visited
     * @param ctx the cascades context, only forwarded to the child visit
     * @return {@code project} itself when nothing below it changed, otherwise a copy of
     *         {@code project} placed on top of the rewritten child
     */
    @Override
    public PhysicalProject visitPhysicalProject(PhysicalProject<? extends Plan> project, CascadesContext ctx) {
        Plan child = project.child();
        Plan newChild = child.accept(this, ctx);
        if (!(newChild instanceof AbstractPhysicalJoin)) {
            // Nothing to prune at this level, but the subtree may have been rewritten.
            return withRewrittenChild(project, child, newChild);
        }

        AbstractPhysicalJoin<? extends Plan, ? extends Plan> join = (AbstractPhysicalJoin) newChild;
        Plan left = join.left();
        Plan right = join.right();
        Set<Slot> leftOutput = left.getOutputSet();
        Set<Slot> rightOutput = right.getOutputSet();

        // Slots needed above the join inputs: those read by the project ...
        Set<Slot> usedSlots = project.getProjects().stream().flatMap(ne -> ne.getInputSlots().stream())
                .collect(Collectors.toSet());
        // ... plus those read by the join conditions themselves.
        Stream.concat(join.getHashJoinConjuncts().stream(), join.getOtherJoinConjuncts().stream())
                .flatMap(expr -> expr.getInputSlots().stream())
                .forEach(usedSlots::add);
        join.getMarkJoinSlotReference().ifPresent(usedSlots::add);

        // Split the used slots by the join input that produces them.
        List<NamedExpression> leftNewProjections = new ArrayList<>();
        List<NamedExpression> rightNewProjections = new ArrayList<>();
        for (Slot usedSlot : usedSlots) {
            if (leftOutput.contains(usedSlot)) {
                leftNewProjections.add(usedSlot);
            } else if (rightOutput.contains(usedSlot)) {
                rightNewProjections.add(usedSlot);
            }
        }

        Plan newLeft = pruneJoinInput(left, leftOutput, leftNewProjections);
        Plan newRight = pruneJoinInput(right, rightOutput, rightNewProjections);

        if (newLeft != left || newRight != right) {
            return (PhysicalProject) project.withChildren(join.withChildren(newLeft, newRight));
        }
        // No pruning at this level; still keep whatever was rewritten deeper in the tree.
        return withRewrittenChild(project, child, newChild);
    }

    /**
     * Returns {@code project} unchanged when the child visit produced the same instance,
     * otherwise a copy of {@code project} on top of {@code newChild}, so that rewrites made
     * deeper in the plan are not lost.
     */
    private PhysicalProject withRewrittenChild(PhysicalProject<? extends Plan> project, Plan child,
            Plan newChild) {
        return newChild == child ? project : (PhysicalProject) project.withChildren(newChild);
    }

    /**
     * Adds a pruning project for one join input.
     *
     * <p>The input is returned as is when no slot of it is used or when the number of used
     * slots equals the number of slots it outputs. For a {@link PhysicalDistribute} input
     * the new project is placed below the distribute; for any other input it is placed on
     * top and takes over the input's stats and group id.
     *
     * @param input the left or right child of the join
     * @param inputOutput the output slot set of {@code input}
     * @param projections the slots of {@code input} that are actually used
     * @return the (possibly wrapped) join input
     */
    private Plan pruneJoinInput(Plan input, Set<Slot> inputOutput, List<NamedExpression> projections) {
        if (projections.isEmpty() || projections.size() == inputOutput.size()) {
            return input;
        }
        if (input instanceof PhysicalDistribute) {
            return input.withChildren(new PhysicalProject<>(projections,
                    input.getLogicalProperties(), input.child(0)));
        }
        return new PhysicalProject<>(projections, input.getLogicalProperties(),
                input).copyStatsAndGroupIdFrom((AbstractPhysicalPlan) input);
    }
}

Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public List<PlanPostProcessor> getProcessors() {
// add processor if we need
Builder<PlanPostProcessor> builder = ImmutableList.builder();
builder.add(new PushdownFilterThroughProject());
builder.add(new ColumnPruningPostProcessor());
builder.add(new MergeProjectPostProcessor());
builder.add(new RecomputeLogicalPropertiesProcessor());
builder.add(new AddOffsetIntoDistribute());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ public PhysicalTopN<? extends Plan> visitPhysicalTopN(PhysicalTopN<? extends Pla
Plan child = topN.child();
topN = rewriteTopN(topN);
if (child != topN.child()) {
topN = ((PhysicalTopN<? extends Plan>) topN.withChildren(child)).copyStatsAndGroupIdFrom(topN);
topN = (PhysicalTopN<? extends Plan>) ((PhysicalTopN<? extends Plan>) topN.withChildren(
child)).copyStatsAndGroupIdFrom(topN);
}
return topN;
} else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
Expand All @@ -94,7 +95,8 @@ public Plan visitPhysicalDeferMaterializeTopN(PhysicalDeferMaterializeTopN<? ext
if (topN.getSortPhase() == SortPhase.LOCAL_SORT) {
PhysicalTopN<? extends Plan> rewrittenTopN = rewriteTopN(topN.getPhysicalTopN());
if (topN.getPhysicalTopN() != rewrittenTopN) {
topN = topN.withPhysicalTopN(rewrittenTopN).copyStatsAndGroupIdFrom(topN);
topN = (PhysicalDeferMaterializeTopN<? extends Plan>) topN.withPhysicalTopN(rewrittenTopN)
.copyStatsAndGroupIdFrom(topN);
}
return topN;
} else if (topN.getSortPhase() == SortPhase.MERGE_SORT) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public Plan getExplainPlan(ConnectContext ctx) {
return this;
}

public <T extends AbstractPhysicalPlan> T copyStatsAndGroupIdFrom(T from) {
public <T extends AbstractPhysicalPlan> AbstractPhysicalPlan copyStatsAndGroupIdFrom(T from) {
T newPlan = (T) withPhysicalPropertiesAndStats(
from.getPhysicalProperties(), from.getStats());
newPlan.setMutableState(MutableState.KEY_GROUP, from.getGroupIdAsString());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.postprocess;

import org.apache.doris.nereids.processor.post.ColumnPruningPostProcessor;
import org.apache.doris.nereids.rules.rewrite.InferFilterNotNull;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Checks that {@link ColumnPruningPostProcessor} puts a project on both inputs of a join
 * when the project above the join selects only some of the joined columns.
 */
class ColumnPruningPostProcessorTest implements MemoPatternMatchSupported {
    private final LogicalOlapScan leftScan = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private final LogicalOlapScan rightScan = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

    @Test
    void test() {
        // Inner join of t1 and t2 without conditions, projecting output indexes 0 and 2.
        LogicalPlanBuilder planBuilder = new LogicalPlanBuilder(leftScan)
                .join(rightScan, JoinType.INNER_JOIN, ImmutableList.of())
                .project(ImmutableList.of(0, 2));
        LogicalPlan logicalPlan = planBuilder.build();

        // Rewrite and implement the logical plan to get a physical plan.
        PlanChecker checker = PlanChecker.from(MemoTestUtils.createConnectContext(), logicalPlan)
                .applyTopDown(new InferFilterNotNull())
                .implement();
        PhysicalPlan implemented = checker.getPhysicalPlan();

        // Run the post processor directly; no cascades context is supplied.
        PhysicalPlan pruned = (PhysicalPlan) implemented.accept(new ColumnPruningPostProcessor(), null);

        // Expected shape: Project -> NestedLoopJoin -> (Project, Project).
        Assertions.assertTrue(pruned instanceof PhysicalProject);
        Assertions.assertTrue(pruned.child(0) instanceof PhysicalNestedLoopJoin);
        Assertions.assertTrue(pruned.child(0).child(0) instanceof PhysicalProject);
        Assertions.assertTrue(pruned.child(0).child(1) instanceof PhysicalProject);
    }
}
13 changes: 7 additions & 6 deletions regression-test/data/nereids_ssb_shape_sf100_p0/shape/q3.4.out
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ PhysicalResultSink
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN](lineorder.lo_orderdate = dates.d_datekey)
------------------------hashJoin[INNER_JOIN](lineorder.lo_suppkey = supplier.s_suppkey)
--------------------------PhysicalProject
----------------------------PhysicalOlapScan[lineorder]
--------------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN](lineorder.lo_suppkey = supplier.s_suppkey)
----------------------------PhysicalProject
------------------------------filter(s_city IN ('UNITED KI1', 'UNITED KI5'))
--------------------------------PhysicalOlapScan[supplier]
------------------------------PhysicalOlapScan[lineorder]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter(s_city IN ('UNITED KI1', 'UNITED KI5'))
----------------------------------PhysicalOlapScan[supplier]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter((dates.d_yearmonth = 'Dec1997'))
Expand Down
25 changes: 13 additions & 12 deletions regression-test/data/nereids_ssb_shape_sf100_p0/shape/q4.3.out
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ PhysicalResultSink
------------------PhysicalDistribute
--------------------PhysicalProject
----------------------hashJoin[INNER_JOIN](lineorder.lo_orderdate = dates.d_datekey)
------------------------hashJoin[INNER_JOIN](lineorder.lo_partkey = part.p_partkey)
--------------------------PhysicalDistribute
----------------------------hashJoin[INNER_JOIN](lineorder.lo_suppkey = supplier.s_suppkey)
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[lineorder]
------------------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------hashJoin[INNER_JOIN](lineorder.lo_partkey = part.p_partkey)
----------------------------PhysicalDistribute
------------------------------hashJoin[INNER_JOIN](lineorder.lo_suppkey = supplier.s_suppkey)
--------------------------------PhysicalProject
----------------------------------filter((supplier.s_nation = 'UNITED STATES'))
------------------------------------PhysicalOlapScan[supplier]
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------filter((part.p_category = 'MFGR#14'))
--------------------------------PhysicalOlapScan[part]
----------------------------------PhysicalOlapScan[lineorder]
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter((supplier.s_nation = 'UNITED STATES'))
--------------------------------------PhysicalOlapScan[supplier]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter((part.p_category = 'MFGR#14'))
----------------------------------PhysicalOlapScan[part]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------filter(d_year IN (1997, 1998))
Expand Down

0 comments on commit e616e2a

Please sign in to comment.