Skip to content

Commit

Permalink
push down RF into cte producer
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Jan 30, 2024
1 parent cd3ef3c commit 0335396
Show file tree
Hide file tree
Showing 12 changed files with 320 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public List<PlanPostProcessor> getProcessors() {
builder.add(new FragmentProcessor());
if (!cascadesContext.getConnectContext().getSessionVariable().getRuntimeFilterMode()
.toUpperCase().equals(TRuntimeFilterMode.OFF.name())) {
builder.add(new RegisterParent());
builder.add(new RuntimeFilterGenerator());
if (ConnectContext.get().getSessionVariable().enableRuntimeFilterPrune) {
builder.add(new RuntimeFilterPruner());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.MutableState;

/**
 * Post processor that records, on every non-root plan node, a reference to its parent.
 * The reference is stored in the child's mutable state under {@code MutableState.KEY_PARENT}.
 */
public class RegisterParent extends PlanPostProcessor {
    @Override
    public Plan visit(Plan plan, CascadesContext context) {
        // tag each child with its parent first, then descend into that child's subtree
        plan.children().forEach(child -> {
            child.setMutableState(MutableState.KEY_PARENT, plan);
            child.accept(this, context);
        });
        return plan;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -99,6 +101,136 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {

private final IdGenerator<RuntimeFilterId> generator = RuntimeFilterId.createGenerator();

/**
 * Generates runtime filters for the whole plan, then tries to push filters that target CTE
 * consumers down into the matching CTE producer.
 *
 * <p>A CTE is considered only when every one of its consumers is targeted by at least one
 * runtime filter and there are at least two such consumers. For each source expression shared
 * by all of those consumers, the filter whose builder node has the longest ancestor chain is
 * handed to the producer-side push down; when that succeeds, {@code rfCtx.removeFilter} is
 * called for that filter's target slot and builder join.
 *
 * @param plan root of the plan to process
 * @param ctx cascades context providing the runtime filter context and the CTE consumer map
 * @return the plan returned by visiting {@code plan} with this processor
 */
@Override
public Plan processRoot(Plan plan, CascadesContext ctx) {
    Plan result = plan.accept(this, ctx);
    // cte rf
    RuntimeFilterContext rfCtx = ctx.getRuntimeFilterContext();
    int cteCount = rfCtx.getProcessedCTE().size();
    if (cteCount == 0) {
        return result;
    }

    // index the generated runtime filters by the cte consumers they target
    Map<CTEId, Set<PhysicalCTEConsumer>> cteIdToConsumersWithRF = Maps.newHashMap();
    Map<PhysicalCTEConsumer, Set<RuntimeFilter>> consumerToRFs = Maps.newHashMap();
    Map<PhysicalCTEConsumer, Set<Expression>> consumerToSrcExpression = Maps.newHashMap();
    List<RuntimeFilter> allRFs = rfCtx.getNereidsRuntimeFilter();
    for (RuntimeFilter rf : allRFs) {
        for (PhysicalRelation rel : rf.getTargetScans()) {
            if (rel instanceof PhysicalCTEConsumer) {
                PhysicalCTEConsumer consumer = (PhysicalCTEConsumer) rel;
                CTEId cteId = consumer.getCteId();
                cteIdToConsumersWithRF.computeIfAbsent(cteId, key -> Sets.newHashSet()).add(consumer);
                consumerToRFs.computeIfAbsent(consumer, key -> Sets.newHashSet()).add(rf);
                consumerToSrcExpression.computeIfAbsent(consumer, key -> Sets.newHashSet())
                        .add(rf.getSrcExpr());
            }
        }
    }

    for (CTEId cteId : rfCtx.getCteProduceMap().keySet()) {
        Set<PhysicalCTEConsumer> consumers = cteIdToConsumersWithRF.get(cteId);
        if (consumers == null) {
            // no runtime filter targets any consumer of this cte, nothing to push down
            continue;
        }
        // if any consumer does not have RF, RF cannot be pushed down.
        // consumers.size() can not be 1, o.w. this cte will be inlined.
        if (ctx.getCteIdToConsumers().get(cteId).size() != consumers.size() || consumers.size() < 2) {
            continue;
        }

        // check if there is a common srcExpr among all the consumers
        PhysicalCTEConsumer consumer0 = consumers.iterator().next();
        Set<Expression> candidateSrcExpressions = consumerToSrcExpression.get(consumer0);
        for (PhysicalCTEConsumer currentConsumer : consumers) {
            candidateSrcExpressions.retainAll(consumerToSrcExpression.get(currentConsumer));
            if (candidateSrcExpressions.isEmpty()) {
                break;
            }
        }

        // find RFs to push down
        for (Expression srcExpr : candidateSrcExpressions) {
            List<RuntimeFilter> rfsToPushDown = Lists.newArrayList();
            for (PhysicalCTEConsumer consumer : consumers) {
                for (RuntimeFilter rf : consumerToRFs.get(consumer)) {
                    if (rf.getSrcExpr().equals(srcExpr)) {
                        rfsToPushDown.add(rf);
                    }
                }
            }
            if (rfsToPushDown.isEmpty()) {
                break;
            }

            // the most right deep buildNode from rfsToPushDown is used as buildNode for pushDown rf
            // since the srcExpr are the same, all buildNodes of rfToPushDown are in the same tree path
            // the longest ancestors means its corresponding rf build node is the most right deep one.
            RuntimeFilter rightDeep = rfsToPushDown.get(0);
            List<Plan> rightDeepAncestors = rightDeep.getBuilderNode().getAncestors();
            int rightDeepAncestorsSize = rightDeepAncestors.size();
            RuntimeFilter leftTop = rfsToPushDown.get(0);
            int leftTopAncestorsSize = rightDeepAncestorsSize;
            for (RuntimeFilter rf : rfsToPushDown) {
                List<Plan> ancestors = rf.getBuilderNode().getAncestors();
                int currentAncestorsSize = ancestors.size();
                if (currentAncestorsSize > rightDeepAncestorsSize) {
                    rightDeep = rf;
                    rightDeepAncestorsSize = currentAncestorsSize;
                    rightDeepAncestors = ancestors;
                }
                if (currentAncestorsSize < leftTopAncestorsSize) {
                    leftTopAncestorsSize = currentAncestorsSize;
                    leftTop = rf;
                }
            }
            Preconditions.checkArgument(rightDeepAncestors.contains(leftTop.getBuilderNode()));

            // check nodes between right deep and left top are SPJ and not denied join and not mark join
            boolean valid = true;
            for (Plan cursor : rightDeepAncestors) {
                if (cursor.equals(leftTop.getBuilderNode())) {
                    break;
                }
                valid = valid && SPJ_PLAN.contains(cursor.getClass());
                if (cursor instanceof AbstractPhysicalJoin) {
                    AbstractPhysicalJoin cursorJoin = (AbstractPhysicalJoin) cursor;
                    // a join on the path is acceptable only if it is neither a denied join type
                    // nor a mark join, the same rule visitPhysicalHashJoin applies
                    valid = valid
                            && !RuntimeFilterGenerator.DENIED_JOIN_TYPES.contains(cursorJoin.getJoinType())
                            && !cursorJoin.isMarkJoin();
                }
                if (!valid) {
                    break;
                }
            }
            if (!valid) {
                break;
            }

            // pick the target expression of rightDeep that sits on a consumer of this cte
            Expression rightDeepTargetExpressionOnCTE = null;
            int targetCount = rightDeep.getTargetExpressions().size();
            for (int i = 0; i < targetCount; i++) {
                PhysicalRelation rel = rightDeep.getTargetScans().get(i);
                if (rel instanceof PhysicalCTEConsumer
                        && ((PhysicalCTEConsumer) rel).getCteId().equals(cteId)) {
                    rightDeepTargetExpressionOnCTE = rightDeep.getTargetExpressions().get(i);
                    break;
                }
            }
            if (rightDeepTargetExpressionOnCTE == null) {
                // without a target expression on this cte there is nothing to hand to the producer
                continue;
            }

            boolean pushedDown = doPushDownIntoCTEProducerInternal(
                    rightDeep,
                    rightDeepTargetExpressionOnCTE,
                    rfCtx,
                    rfCtx.getCteProduceMap().get(cteId)
            );
            if (pushedDown) {
                rfCtx.removeFilter(
                        rightDeepTargetExpressionOnCTE.getInputSlotExprIds().iterator().next(),
                        (PhysicalHashJoin) rightDeep.getBuilderNode());
            }
        }
    }
    return result;
}

/**
* the runtime filter generator run at the phase of post process and plan translation of nereids planner.
* post process:
Expand All @@ -117,19 +249,20 @@ public class RuntimeFilterGenerator extends PlanPostProcessor {
@Override
public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        CascadesContext context) {
    // the right child is visited before the left child
    join.right().accept(this, context);
    join.left().accept(this, context);
    if (RuntimeFilterGenerator.DENIED_JOIN_TYPES.contains(join.getJoinType()) || join.isMarkJoin()) {
        // for a denied join type or a mark join, drop the alias-transfer entries
        // of every slot output by the right child
        join.right().getOutput().forEach(slot ->
                context.getRuntimeFilterContext().aliasTransferMapRemove(slot));
    }
    // Filters are generated exactly once per join through the common path.
    // Pushing filters into CTE producers is done afterwards in processRoot,
    // so the per-join CTE branch is no longer invoked here.
    pushDownRuntimeFilterCommon(join, context);
    return join;
}

Expand Down Expand Up @@ -381,6 +514,14 @@ private void pushDownRuntimeFilterCommon(PhysicalHashJoin<? extends Plan, ? exte
continue;
}
long buildSideNdv = getBuildSideNdv(join, equalTo);
Pair<PhysicalRelation, Slot> pair = ctx.getAliasTransferMap().get(equalTo.right());
if (pair == null) {
continue;
}
if (pair.first instanceof PhysicalCTEConsumer) {
// CteConsumer is not allowed to generate RF in order to avoid RF cycle.
continue;
}
join.pushDownRuntimeFilter(context, generator, join, equalTo.right(),
equalTo.left(), type, buildSideNdv, i);
}
Expand Down Expand Up @@ -579,6 +720,65 @@ private void pushDownRuntimeFilterIntoCTE(RuntimeFilterContext ctx) {
}
}

/**
 * Tries to re-create runtime filter {@code rf} on the relations underneath a CTE producer.
 *
 * <p>The push down is attempted only when the alias-transfer entry of the target slot points
 * at a {@link PhysicalCTEConsumer}, the producer's child is a {@link PhysicalProject}, and the
 * project output with the same name as the cte slot is a plain {@link SlotReference}.
 *
 * @param rf the runtime filter to push down; its source expression, type, expression order,
 *         builder node and build side ndv are reused for the new filter
 * @param targetExpression the target expression of {@code rf} on the cte consumer
 * @param ctx runtime filter context that is updated with the new targets and filter
 * @param cteProducer producer whose child plan receives the filter
 * @return {@code true} if a new filter with at least one target was registered,
 *         {@code false} if nothing was pushed down
 */
private boolean doPushDownIntoCTEProducerInternal(RuntimeFilter rf, Expression targetExpression,
        RuntimeFilterContext ctx, PhysicalCTEProducer cteProducer) {
    PhysicalPlan inputPlanNode = (PhysicalPlan) cteProducer.child(0);
    Slot unwrappedSlot = checkTargetChild(targetExpression);
    // aliasTransMap doesn't contain the key, means that the path from the scan to the join
    // contains join with denied join type. for example: a left join b on a.id = b.id
    if (!checkPushDownPreconditionsForJoin(rf.getBuilderNode(), ctx, unwrappedSlot)) {
        return false;
    }
    Slot cteSlot = ctx.getAliasTransferPair(unwrappedSlot).second;
    PhysicalRelation cteNode = ctx.getAliasTransferPair(unwrappedSlot).first;
    long buildSideNdv = rf.getBuildSideNdv();
    if (!(cteNode instanceof PhysicalCTEConsumer) || !(inputPlanNode instanceof PhysicalProject)) {
        return false;
    }

    // locate the producer-side projection that has the same name as the cte slot
    PhysicalProject project = (PhysicalProject) inputPlanNode;
    NamedExpression targetExpr = null;
    for (Object column : project.getProjects()) {
        NamedExpression alias = (NamedExpression) column;
        if (cteSlot.getName().equals(alias.getName())) {
            targetExpr = alias;
            break;
        }
    }
    Preconditions.checkState(targetExpr != null);
    if (!(targetExpr instanceof SlotReference) || !checkCanPushDownIntoBasicTable(project)) {
        return false;
    }

    Map<Slot, PhysicalRelation> pushDownBasicTableInfos = getPushDownBasicTablesInfos(project,
            (SlotReference) targetExpr, ctx);
    if (pushDownBasicTableInfos.isEmpty()) {
        return false;
    }

    List<Slot> targetList = new ArrayList<>();
    List<Expression> targetExpressions = new ArrayList<>();
    List<PhysicalRelation> targetNodes = new ArrayList<>();
    for (Map.Entry<Slot, PhysicalRelation> entry : pushDownBasicTableInfos.entrySet()) {
        Slot targetSlot = entry.getKey();
        PhysicalRelation scan = entry.getValue();
        if (!RuntimeFilterGenerator.checkPushDownPreconditionsForRelation(project, scan)) {
            continue;
        }
        targetList.add(targetSlot);
        targetExpressions.add(targetSlot);
        targetNodes.add(scan);
        ctx.addJoinToTargetMap(rf.getBuilderNode(), targetSlot.getExprId());
        ctx.setTargetsOnScanNode(scan, targetSlot);
    }
    if (targetList.isEmpty()) {
        // every candidate relation was rejected: do not register a filter without targets and
        // do not report success, otherwise the caller would drop the consumer-side filter
        return false;
    }

    RuntimeFilter filter = new RuntimeFilter(generator.getNextId(),
            rf.getSrcExpr(), targetList, targetExpressions, rf.getType(), rf.getExprOrder(),
            rf.getBuilderNode(), buildSideNdv, rf.isBloomFilterSizeCalculatedByNdv(),
            cteNode);
    targetNodes.forEach(node -> node.addAppliedRuntimeFilter(filter));
    for (Slot slot : targetList) {
        ctx.setTargetExprIdToFilter(slot.getExprId(), filter);
    }
    ctx.setRuntimeFilterIdentityToFilter(rf.getSrcExpr(), rf.getType(), rf.getBuilderNode(), filter);
    return true;
}

private void doPushDownIntoCTEProducerInternal(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
RuntimeFilterContext ctx, EqualTo equalTo, TRuntimeFilterType type, PhysicalCTEProducer cteProducer) {
PhysicalPlan inputPlanNode = (PhysicalPlan) cteProducer.child(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.google.common.base.Preconditions;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Lists;
import org.json.JSONArray;
import org.json.JSONObject;

Expand Down Expand Up @@ -191,4 +192,18 @@ public void setMutableState(String key, Object state) {
public int getId() {
return id.asInt();
}

/**
 * Collects this plan followed by every plan reachable through the
 * {@code MutableState.KEY_PARENT} links, nearest parent first.
 *
 * @return a new list that starts with this plan; it holds only this plan
 *         when no parent has been recorded
 */
public List<Plan> getAncestors() {
    List<Plan> chain = Lists.newArrayList();
    Plan cursor = this;
    while (true) {
        chain.add(cursor);
        Optional<Object> up = cursor.getMutableState(MutableState.KEY_PARENT);
        if (!up.isPresent()) {
            return chain;
        }
        cursor = (Plan) up.get();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,13 @@ public Slot getProducerSlot(Slot consumerSlot) {

@Override
public String toString() {
StringBuilder builder = new StringBuilder();
if (!getAppliedRuntimeFilters().isEmpty()) {
getAppliedRuntimeFilters()
.stream().forEach(rf -> builder.append(" RF").append(rf.getId().asInt()));
}
return Utils.toSqlString("PhysicalCTEConsumer[" + id.asInt() + "]",
"cteId", cteId);
"stats", getStats(), "cteId", cteId, "RFs", builder);
}

@Override
Expand Down Expand Up @@ -136,8 +141,15 @@ public PhysicalCTEConsumer withPhysicalPropertiesAndStats(

@Override
public String shapeInfo() {
    // start from the plain consumer shape; the former early return is removed so that the
    // runtime filter suffix below is reachable
    StringBuilder shapeBuilder = new StringBuilder();
    shapeBuilder.append(Utils.toSqlString("PhysicalCteConsumer",
            "cteId", cteId));
    if (!getAppliedRuntimeFilters().isEmpty()) {
        // append one " RF<id>" token per applied runtime filter
        shapeBuilder.append(" apply RFs:");
        getAppliedRuntimeFilters().forEach(rf -> shapeBuilder.append(" RF").append(rf.getId().asInt()));
    }
    return shapeBuilder.toString();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ public PhysicalDistribute(DistributionSpec spec, Optional<GroupExpression> group
@Override
public String toString() {
return Utils.toSqlString("PhysicalDistribute[" + id.asInt() + "]" + getGroupIdWithPrefix(),
"distributionSpec", distributionSpec,
"stats", statistics
"stats", statistics,
"distributionSpec", distributionSpec
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ public List<Expression> getExpressions() {
@Override
public String toString() {
return Utils.toSqlString("PhysicalFilter[" + id.asInt() + "]" + getGroupIdWithPrefix(),
"predicates", getPredicate(),
"stats", statistics
"stats", statistics,
"predicates", getPredicate()
);
}

Expand Down
Loading

0 comments on commit 0335396

Please sign in to comment.