Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pick](nereids) push down runtime filter to all children of SetOperation Branch 2.1 #33778

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,10 @@ public PhysicalPlan visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? ext
|| (pair != null && pair.first instanceof PhysicalCTEConsumer)) {
continue;
}
join.pushDownRuntimeFilter(context, generator, join, equalTo.right(),
equalTo.left(), type, buildSideNdv, i);
if (equalTo.left().getInputSlots().size() == 1) {
join.pushDownRuntimeFilter(context, generator, join, equalTo.right(),
equalTo.left(), type, buildSideNdv, i);
}
}
}
return join;
Expand Down Expand Up @@ -336,7 +338,7 @@ private void generateBitMapRuntimeFilterForNLJ(PhysicalNestedLoopJoin<? extends
TRuntimeFilterType type = TRuntimeFilterType.BITMAP;
Set<Slot> targetSlots = bitmapContains.child(1).getInputSlots();
for (Slot targetSlot : targetSlots) {
if (!checkPushDownPreconditionsForJoin(join, ctx, targetSlot)) {
if (!checkProbeSlot(ctx, targetSlot)) {
continue;
}
Slot scanSlot = ctx.getAliasTransferPair(targetSlot).second;
Expand Down Expand Up @@ -520,7 +522,7 @@ private boolean doPushDownIntoCTEProducerInternal(RuntimeFilter rf, Expression t
Slot unwrappedSlot = checkTargetChild(targetExpression);
// aliasTransMap doesn't contain the key, means that the path from the scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (!checkPushDownPreconditionsForJoin(rf.getBuilderNode(), ctx, unwrappedSlot)) {
if (!checkProbeSlot(ctx, unwrappedSlot)) {
return false;
}
Slot cteSlot = ctx.getAliasTransferPair(unwrappedSlot).second;
Expand Down Expand Up @@ -605,17 +607,13 @@ public static boolean checkPushDownPreconditionsForProjectOrDistribute(RuntimeFi
}

/**
* Check runtime filter push down pre-conditions, such as builder side join type, etc.
* check if slot is in ctx.aliasTransferMap
*/
public static boolean checkPushDownPreconditionsForJoin(AbstractPhysicalJoin physicalJoin,
RuntimeFilterContext ctx, Slot slot) {
public static boolean checkProbeSlot(RuntimeFilterContext ctx, Slot slot) {
if (slot == null || !ctx.aliasTransferMapContains(slot)) {
return false;
} else if (DENIED_JOIN_TYPES.contains(physicalJoin.getJoinType()) || physicalJoin.isMarkJoin()) {
return false;
} else {
return true;
}
return true;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;

import java.util.Map;

/**
* This is the factory for all ExpressionVisitor instance.
* All children instance of DefaultExpressionVisitor or ExpressionVisitor for common usage
Expand All @@ -29,6 +31,7 @@
public class ExpressionVisitors {

public static final ContainsAggregateChecker CONTAINS_AGGREGATE_CHECKER = new ContainsAggregateChecker();
public static final ExpressionMapReplacer EXPRESSION_MAP_REPLACER = new ExpressionMapReplacer();

private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
@Override
Expand All @@ -54,4 +57,22 @@ public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void
return true;
}
}

/**
* replace sub expr by Map
*/
public static class ExpressionMapReplacer
extends DefaultExpressionRewriter<Map<Expression, Expression>> {

private ExpressionMapReplacer() {
}

@Override
public Expression visit(Expression expr, Map<Expression, Expression> replaceMap) {
if (replaceMap.containsKey(expr)) {
return replaceMap.get(expr);
}
return super.visit(expr, replaceMap);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<Runtim

// aliasTransMap doesn't contain the key, means that the path from the scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
if (!RuntimeFilterGenerator.checkProbeSlot(ctx, probeSlot)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<Runtim

// aliasTransMap doesn't contain the key, means that the path from the scan to the join
// contains join with denied join type. for example: a left join b on a.id = b.id
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, probeSlot)) {
if (!RuntimeFilterGenerator.checkProbeSlot(ctx, probeSlot)) {
return false;
}
PhysicalRelation scan = ctx.getAliasTransferPair(probeSlot).first;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<Runtim
}
}
Slot newProbeSlot = RuntimeFilterGenerator.checkTargetChild(newProbeExpr);
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, newProbeSlot)) {
if (!RuntimeFilterGenerator.checkProbeSlot(ctx, newProbeSlot)) {
return false;
}
scan = ctx.getAliasTransferPair(newProbeSlot).first;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@
import org.apache.doris.common.IdGenerator;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
Expand All @@ -38,8 +37,10 @@
import org.apache.doris.thrift.TRuntimeFilterType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

Expand Down Expand Up @@ -148,42 +149,31 @@ public int getArity() {
public boolean pushDownRuntimeFilter(CascadesContext context, IdGenerator<RuntimeFilterId> generator,
AbstractPhysicalJoin<?, ?> builderNode, Expression src, Expression probeExpr,
TRuntimeFilterType type, long buildSideNdv, int exprOrder) {
RuntimeFilterContext ctx = context.getRuntimeFilterContext();
boolean pushedDown = false;
int projIndex = -1;
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
if (probeSlot == null) {
return false;
}
List<NamedExpression> output = getOutputs();
for (int j = 0; j < output.size(); j++) {
NamedExpression expr = output.get(j);
if (expr.getName().equals(probeSlot.getName())) {
projIndex = j;
break;
}
}
if (projIndex == -1) {
return false;
}
for (int i = 0; i < this.children().size(); i++) {
Map<Expression, Expression> map = Maps.newHashMap();
// probeExpr only has one input slot
map.put(probeExpr.getInputSlots().iterator().next(), regularChildrenOutputs.get(i).get(projIndex));
Expression newProbeExpr = probeExpr.accept(ExpressionVisitors.EXPRESSION_MAP_REPLACER, map);
AbstractPhysicalPlan child = (AbstractPhysicalPlan) this.child(i);
// TODO: replace this special logic with dynamic handling and the name matching
if (child instanceof PhysicalDistribute) {
child = (AbstractPhysicalPlan) child.child(0);
}
if (child instanceof PhysicalProject) {
PhysicalProject<?> project = (PhysicalProject<?>) child;
int projIndex = -1;
Slot probeSlot = RuntimeFilterGenerator.checkTargetChild(probeExpr);
if (probeSlot == null) {
continue;
}
for (int j = 0; j < project.getProjects().size(); j++) {
NamedExpression expr = project.getProjects().get(j);
if (expr.getName().equals(probeSlot.getName())) {
projIndex = j;
break;
}
}
if (projIndex < 0 || projIndex >= project.getProjects().size()) {
continue;
}
Expression newProbeExpr = project.getProjects().get(projIndex);
if (newProbeExpr instanceof Alias) {
newProbeExpr = newProbeExpr.child(0);
}
Slot newProbeSlot = RuntimeFilterGenerator.checkTargetChild(newProbeExpr);
if (!RuntimeFilterGenerator.checkPushDownPreconditionsForJoin(builderNode, ctx, newProbeSlot)) {
continue;
}
pushedDown |= child.pushDownRuntimeFilter(context, generator, builderNode, src,
pushedDown |= child.pushDownRuntimeFilter(context, generator, builderNode, src,
newProbeExpr, type, buildSideNdv, exprOrder);
}
}
return pushedDown;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------PhysicalProject
--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = d1.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = iss.i_item_sk)) otherCondition=()
------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_item_sk = iss.i_item_sk)) otherCondition=() build RFs:RF0 i_item_sk->[ss_item_sk]
--------------------PhysicalProject
----------------------PhysicalOlapScan[store_sales] apply RFs: RF1
----------------------PhysicalOlapScan[store_sales] apply RFs: RF0 RF1
--------------------PhysicalDistribute[DistributionSpecReplicated]
----------------------PhysicalProject
------------------------PhysicalOlapScan[item]
Expand All @@ -23,9 +23,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------PhysicalProject
--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = d2.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[cs_sold_date_sk]
----------------PhysicalProject
------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = ics.i_item_sk)) otherCondition=()
------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_item_sk = ics.i_item_sk)) otherCondition=() build RFs:RF2 i_item_sk->[cs_item_sk]
--------------------PhysicalProject
----------------------PhysicalOlapScan[catalog_sales] apply RFs: RF3
----------------------PhysicalOlapScan[catalog_sales] apply RFs: RF2 RF3
--------------------PhysicalDistribute[DistributionSpecReplicated]
----------------------PhysicalProject
------------------------PhysicalOlapScan[item]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalProject
------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalProject
------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 i_brand_id->[i_brand_id];RF7 i_class_id->[i_class_id];RF8 i_category_id->[i_category_id]
------hashJoin[INNER_JOIN] hashCondition=((item.i_brand_id = t.brand_id) and (item.i_category_id = t.category_id) and (item.i_class_id = t.class_id)) otherCondition=() build RFs:RF6 i_brand_id->[i_brand_id,i_brand_id,i_brand_id];RF7 i_class_id->[i_class_id,i_class_id,i_class_id];RF8 i_category_id->[i_category_id,i_category_id,i_category_id]
--------PhysicalIntersect
----------PhysicalDistribute[DistributionSpecHash]
------------PhysicalProject
Expand All @@ -28,7 +28,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalOlapScan[catalog_sales] apply RFs: RF2 RF3
--------------------PhysicalDistribute[DistributionSpecReplicated]
----------------------PhysicalProject
------------------------PhysicalOlapScan[item]
------------------------PhysicalOlapScan[item] apply RFs: RF6 RF7 RF8
----------------PhysicalDistribute[DistributionSpecReplicated]
------------------PhysicalProject
--------------------filter((d2.d_year <= 2002) and (d2.d_year >= 2000))
Expand All @@ -42,7 +42,7 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalOlapScan[web_sales] apply RFs: RF4 RF5
--------------------PhysicalDistribute[DistributionSpecReplicated]
----------------------PhysicalProject
------------------------PhysicalOlapScan[item]
------------------------PhysicalOlapScan[item] apply RFs: RF6 RF7 RF8
----------------PhysicalDistribute[DistributionSpecReplicated]
------------------PhysicalProject
--------------------filter((d3.d_year <= 2002) and (d3.d_year >= 2000))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalProject
------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 2001))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
----------------------PhysicalProject
------------------------filter((if((avg_monthly_sales > 0.0000), (cast(abs((sum_sales - cast(avg_monthly_sales as DECIMALV3(38, 2)))) as DECIMALV3(38, 10)) / avg_monthly_sales), NULL) > 0.100000) and (v2.avg_monthly_sales > 0.0000) and (v2.d_year = 1999))
--------------------------PhysicalCteConsumer ( cteId=CTEId#0 )

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !rf_setop --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalProject
----------hashJoin[INNER_JOIN] hashCondition=((T.l_linenumber = expr_cast(r_regionkey as BIGINT))) otherCondition=() build RFs:RF0 expr_cast(r_regionkey as BIGINT)->[cast(l_linenumber as BIGINT),o_orderkey]
------------PhysicalExcept
--------------PhysicalProject
----------------hashAgg[GLOBAL]
------------------PhysicalDistribute[DistributionSpecHash]
--------------------hashAgg[LOCAL]
----------------------PhysicalProject
------------------------PhysicalOlapScan[lineitem] apply RFs: RF0
--------------PhysicalDistribute[DistributionSpecHash]
----------------PhysicalProject
------------------hashAgg[LOCAL]
--------------------PhysicalProject
----------------------PhysicalOlapScan[orders] apply RFs: RF0
------------PhysicalDistribute[DistributionSpecHash]
--------------PhysicalProject
----------------PhysicalOlapScan[region]

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

suite("test_pushdown_setop") {
String db = context.config.getDbNameByFile(new File(context.file.parent))
sql "use ${db}"
sql 'set enable_nereids_planner=true'
sql 'set enable_fallback_to_original_planner=false'
sql 'set exec_mem_limit=21G'
sql 'set be_number_for_test=3'
sql 'set parallel_fragment_exec_instance_num=8; '
sql 'set parallel_pipeline_task_num=8; '
sql 'set forbid_unknown_col_stats=true'
sql 'set enable_nereids_timeout = false'
sql 'set enable_runtime_filter_prune=false'
sql 'set runtime_filter_type=8'
def query = """ select count() from ((select l_linenumber from lineitem) except (select o_orderkey from orders)) T join region on T.l_linenumber = r_regionkey;"""
qt_rf_setop """
explain shape plan
${query}
"""
}

Loading