Skip to content

Commit

Permalink
support limit->proj, support topn-agg
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed May 22, 2024
1 parent ec8782b commit 8c1ead2
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends Plan> d
&& context.getFirstAggregateInFragment(inputFragment) == child) {
PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) child;
if (hashAggregate.getAggPhase() == AggPhase.LOCAL
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) {
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER
&& hashAggregate.getTopnPushInfo() == null) {
AggregationNode aggregationNode = (AggregationNode) inputFragment.getPlanRoot();
aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream());
}
Expand Down Expand Up @@ -1035,6 +1036,23 @@ public PlanFragment visitPhysicalHashAggregate(
// local exchanger will be used.
aggregationNode.setColocate(true);
}
if (aggregate.getTopnPushInfo() != null) {
List<Expr> orderingExprs = Lists.newArrayList();
List<Boolean> ascOrders = Lists.newArrayList();
List<Boolean> nullsFirstParams = Lists.newArrayList();
aggregate.getTopnPushInfo().orderkeys.forEach(k -> {
orderingExprs.add(ExpressionTranslator.translate(k.getExpr(), context));
ascOrders.add(k.isAsc());
nullsFirstParams.add(k.isNullFirst());
});
SortInfo sortInfo = new SortInfo(orderingExprs, ascOrders, nullsFirstParams, outputTupleDesc);
aggregationNode.setSortByGroupKey(sortInfo);
if (aggregationNode.getLimit() == -1) {
aggregationNode.setLimit(aggregate.getTopnPushInfo().limit);
}
} else {
aggregationNode.setSortByGroupKey(null);
}
setPlanRoot(inputPlanFragment, aggregationNode, aggregate);
if (aggregate.getStats() != null) {
aggregationNode.setCardinality((long) aggregate.getStats().getRowCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public List<PlanPostProcessor> getProcessors() {
builder.add(new RecomputeLogicalPropertiesProcessor());
builder.add(new AddOffsetIntoDistribute());
builder.add(new CommonSubExpressionOpt());
if (cascadesContext.getConnectContext().getSessionVariable().pushLimitToLocalAgg) {
builder.add(new PushLimitToLocalAgg());
}
// DO NOT replace PLAN NODE from here
builder.add(new TopNScanOpt());
builder.add(new FragmentProcessor());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/AggregationNode.java
// and modified by Doris

package org.apache.doris.nereids.processor.post;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;

import org.apache.hadoop.util.Lists;

import java.util.List;
import java.util.stream.Collectors;

/**
 * Post processor that attaches a {@link TopnPushInfo} to both phases of a two-phase aggregation
 * sitting directly (or through one project) below a limit / topn.
 *
 * Pattern1:
 * limit(n) -> [project] -> aggGlobal -> distribute -> aggLocal
 * =>
 * limit(n) -> [project] -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
 * Pattern2: topn order keys are the prefix of group keys
 * topn(n) -> [project] -> aggGlobal -> distribute -> aggLocal
 * =>
 * topn(n) -> [project] -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
 *
 * In both patterns the pushed n is limit + offset of the limit / topn node.
 * The aggregates are annotated in place; the plan tree shape is not changed.
 */
public class PushLimitToLocalAgg extends PlanPostProcessor {
    /**
     * Pattern2: annotates the aggregates only when the topn order keys are a prefix of the
     * group-by keys, then continues the traversal into the child.
     */
    @Override
    public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) {
        PhysicalHashAggregate<? extends Plan> upperAgg = findGlobalAggOverLocalAgg(topN.child());
        if (upperAgg != null) {
            List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
            if (!orderKeys.isEmpty()) {
                // both phases share the same order keys, derived from the upper agg
                upperAgg.setTopnPushInfo(new TopnPushInfo(
                        orderKeys,
                        topN.getLimit() + topN.getOffset()));
                localAggBelow(upperAgg).setTopnPushInfo(new TopnPushInfo(
                        orderKeys,
                        topN.getLimit() + topN.getOffset()));
            }
        }
        topN.child().accept(this, ctx);
        return topN;
    }

    /**
     * Pattern1: a plain limit has no ordering requirement, so every group-by key (ascending,
     * nulls last) is used as order key. Each phase derives the keys from its own group-by list.
     */
    @Override
    public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) {
        PhysicalHashAggregate<? extends Plan> upperAgg = findGlobalAggOverLocalAgg(limit.child());
        if (upperAgg != null) {
            // NOTE(review): for an aggregate without group-by keys the pushed order-key list is
            // empty -- confirm that the consumer of TopnPushInfo handles this case.
            upperAgg.setTopnPushInfo(new TopnPushInfo(
                    generateOrderKeyByGroupKey(upperAgg),
                    limit.getLimit() + limit.getOffset()));
            PhysicalHashAggregate<? extends Plan> bottomAgg = localAggBelow(upperAgg);
            bottomAgg.setTopnPushInfo(new TopnPushInfo(
                    generateOrderKeyByGroupKey(bottomAgg),
                    limit.getLimit() + limit.getOffset()));
        }
        limit.child().accept(this, ctx);

        return limit;
    }

    /**
     * Matches {@code [project ->] aggGlobal -> distribute -> agg} starting at {@code plan}.
     *
     * @param plan the direct child of the limit / topn node
     * @return the global-phase aggregate if the pattern matches, otherwise null
     */
    private PhysicalHashAggregate<? extends Plan> findGlobalAggOverLocalAgg(Plan plan) {
        // at most one project between limit / topn and the aggregate is skipped
        Plan candidate = plan instanceof PhysicalProject ? plan.child(0) : plan;
        if (!(candidate instanceof PhysicalHashAggregate)) {
            return null;
        }
        PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) candidate;
        boolean matched = upperAgg.getAggPhase().isGlobal()
                && upperAgg.child() instanceof PhysicalDistribute
                && upperAgg.child().child(0) instanceof PhysicalHashAggregate;
        return matched ? upperAgg : null;
    }

    /**
     * Returns the aggregate below the distribute of {@code upperAgg}.
     * Must only be called on a result of {@link #findGlobalAggOverLocalAgg(Plan)}.
     */
    private PhysicalHashAggregate<? extends Plan> localAggBelow(PhysicalHashAggregate<? extends Plan> upperAgg) {
        return (PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0);
    }

    /**
     * Builds one order key per group-by expression, in group-by order: ascending, nulls last.
     */
    private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) {
        return agg.getGroupByExpressions().stream()
                .map(key -> new OrderKey(key, true, false))
                .collect(Collectors.toList());
    }

    /**
     * Builds the order keys to push into the aggregate for a topn.
     *
     * If the topn order-key expressions are a positional prefix of the agg group-by keys, the
     * result is the topn order keys (keeping their asc/desc and nulls-first settings) followed
     * by the remaining group-by keys as ascending, nulls-last order keys.
     * Otherwise the result is an empty list, which callers treat as "do not push".
     *
     * TODO order-key can be subset of group-key. BE does not support now.
     */
    private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
            PhysicalHashAggregate<? extends Plan> agg) {
        List<OrderKey> topnKeys = topN.getOrderKeys();
        List<Expression> groupKeys = agg.getGroupByExpressions();
        if (topnKeys.size() > groupKeys.size()) {
            return Lists.newArrayList();
        }
        List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(groupKeys.size());
        for (int i = 0; i < topnKeys.size(); i++) {
            // prefix check: only the expression is compared, asc/desc and nulls-first are ignored
            if (!topnKeys.get(i).getExpr().equals(groupKeys.get(i))) {
                return Lists.newArrayList();
            }
            orderKeys.add(topnKeys.get(i));
        }
        for (int i = topnKeys.size(); i < groupKeys.size(); i++) {
            orderKeys.add(new OrderKey(groupKeys.get(i), true, false));
        }
        return orderKeys;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequireProperties;
import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
Expand Down Expand Up @@ -60,6 +61,9 @@ public class PhysicalHashAggregate<CHILD_TYPE extends Plan> extends PhysicalUnar

private final RequireProperties requireProperties;

// only used in post processor
private TopnPushInfo topnPushInfo = null;

public PhysicalHashAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
AggregateParam aggregateParam, boolean maybeUsingStream, LogicalProperties logicalProperties,
RequireProperties requireProperties, CHILD_TYPE child) {
Expand Down Expand Up @@ -196,6 +200,7 @@ public String toString() {
"outputExpr", outputExpressions,
"partitionExpr", partitionExpressions,
"requireProperties", requireProperties,
"topnOpt", topnPushInfo != null,
"stats", statistics
);
}
Expand Down Expand Up @@ -231,19 +236,22 @@ public PhysicalHashAggregate<Plan> withChildren(List<Plan> children) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
children.get(0));
children.get(0))
.setTopnPushInfo(topnPushInfo);
}

public PhysicalHashAggregate<CHILD_TYPE> withPartitionExpressions(List<Expression> partitionExpressions) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions,
Optional.ofNullable(partitionExpressions), aggregateParam, maybeUsingStream,
Optional.empty(), getLogicalProperties(), requireProperties, child());
Optional.empty(), getLogicalProperties(), requireProperties, child())
.setTopnPushInfo(topnPushInfo);
}

@Override
public PhysicalHashAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child());
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(), requireProperties, child())
.setTopnPushInfo(topnPushInfo);
}

@Override
Expand All @@ -252,7 +260,7 @@ public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpr
Preconditions.checkArgument(children.size() == 1);
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, logicalProperties.get(),
requireProperties, children.get(0));
requireProperties, children.get(0)).setTopnPushInfo(topnPushInfo);
}

@Override
Expand All @@ -261,21 +269,21 @@ public PhysicalHashAggregate<CHILD_TYPE> withPhysicalPropertiesAndStats(Physical
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, groupExpression, getLogicalProperties(),
requireProperties, physicalProperties, statistics,
child());
child()).setTopnPushInfo(topnPushInfo);
}

@Override
public PhysicalHashAggregate<CHILD_TYPE> withAggOutput(List<NamedExpression> newOutput) {
return new PhysicalHashAggregate<>(groupByExpressions, newOutput, partitionExpressions,
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
requireProperties, physicalProperties, statistics, child());
requireProperties, physicalProperties, statistics, child()).setTopnPushInfo(topnPushInfo);
}

public <C extends Plan> PhysicalHashAggregate<C> withRequirePropertiesAndChild(
RequireProperties requireProperties, C newChild) {
return new PhysicalHashAggregate<>(groupByExpressions, outputExpressions, partitionExpressions,
aggregateParam, maybeUsingStream, Optional.empty(), getLogicalProperties(),
requireProperties, physicalProperties, statistics, newChild);
requireProperties, physicalProperties, statistics, newChild).setTopnPushInfo(topnPushInfo);
}

@Override
Expand All @@ -299,4 +307,26 @@ public PhysicalHashAggregate<CHILD_TYPE> resetLogicalProperties() {
requireProperties, physicalProperties, statistics,
child());
}

/**
 * used to push limit down to localAgg.
 * Immutable holder: the order keys are copied into an immutable list on construction and
 * neither field can be reassigned afterwards.
 */
public static class TopnPushInfo {
    /** order keys pushed to the aggregate; immutable copy of the constructor argument */
    public final List<OrderKey> orderkeys;
    /** number of rows pushed to the aggregate */
    public final long limit;

    /**
     * @param orderkeys order keys to push; copied, so later changes to the argument are not seen
     * @param limit number of rows to push
     */
    public TopnPushInfo(List<OrderKey> orderkeys, long limit) {
        this.orderkeys = ImmutableList.copyOf(orderkeys);
        this.limit = limit;
    }
}

/**
 * Returns the topn push info attached to this aggregate, or null if none has been set.
 */
public TopnPushInfo getTopnPushInfo() {
return topnPushInfo;
}

/**
 * Attaches topn push info to this aggregate.
 * This mutates the current instance in place (no copy is created) and returns {@code this}
 * so the call can be chained.
 *
 * @param topnPushInfo the info to attach; null clears it
 * @return this aggregate
 */
public PhysicalHashAggregate<CHILD_TYPE> setTopnPushInfo(TopnPushInfo topnPushInfo) {
this.topnPushInfo = topnPushInfo;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.SlotDescriptor;
import org.apache.doris.analysis.SlotId;
import org.apache.doris.analysis.SortInfo;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.common.NotImplementedException;
import org.apache.doris.common.UserException;
Expand Down Expand Up @@ -65,6 +66,8 @@ public class AggregationNode extends PlanNode {
// If true, use streaming preaggregation algorithm. Not valid if this is a merge agg.
private boolean useStreamingPreagg;

private SortInfo sortByGroupKey;

/**
* Create an agg node that is not an intermediate node.
* isIntermediate is true if it is a slave node in a 2-part agg plan.
Expand Down Expand Up @@ -288,6 +291,9 @@ protected void toThrift(TPlanNode msg) {
msg.agg_node.setUseStreamingPreaggregation(useStreamingPreagg);
msg.agg_node.setIsFirstPhase(aggInfo.isFirstPhase());
msg.agg_node.setIsColocate(isColocate);
if (sortByGroupKey != null) {
msg.agg_node.setAggSortInfoByGroupKey(sortByGroupKey.toThrift());
}
List<Expr> groupingExprs = aggInfo.getGroupingExprs();
if (groupingExprs != null) {
msg.agg_node.setGroupingExprs(Expr.treesToThrift(groupingExprs));
Expand Down Expand Up @@ -333,6 +339,7 @@ public String getNodeExplainString(String detailPrefix, TExplainLevel detailLeve
if (!conjuncts.isEmpty()) {
output.append(detailPrefix).append("having: ").append(getExplainString(conjuncts)).append("\n");
}
output.append(detailPrefix).append("sortByGroupKey:").append(sortByGroupKey != null).append("\n");
output.append(detailPrefix).append(String.format(
"cardinality=%,d", cardinality)).append("\n");
return output.toString();
Expand Down Expand Up @@ -411,4 +418,13 @@ public void finalize(Analyzer analyzer) throws UserException {
/**
 * Sets the isColocate flag of this aggregation node.
 */
public void setColocate(boolean colocate) {
isColocate = colocate;
}


/**
 * Returns true if a sort info by group key has been set on this node (i.e. it is not null).
 */
public boolean isSortByGroupKey() {
return sortByGroupKey != null;
}

/**
 * Sets the sort info by group key for this node.
 *
 * @param sortByGroupKey the sort info; null means no sort by group key
 */
public void setSortByGroupKey(SortInfo sortByGroupKey) {
this.sortByGroupKey = sortByGroupKey;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,9 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
@VariableMgr.VarAttr(name = REWRITE_OR_TO_IN_PREDICATE_THRESHOLD, fuzzy = true)
private int rewriteOrToInPredicateThreshold = 2;

@VariableMgr.VarAttr(name = "push_limit_to_local_agg", fuzzy = false, needForward = true)
public boolean pushLimitToLocalAgg = true;

@VariableMgr.VarAttr(name = NEREIDS_CBO_PENALTY_FACTOR, needForward = true)
private double nereidsCboPenaltyFactor = 0.7;

Expand Down
2 changes: 1 addition & 1 deletion gensrc/thrift/PlanNodes.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ struct TAggregationNode {
7: optional list<TSortInfo> agg_sort_infos
8: optional bool is_first_phase
9: optional bool is_colocate
// 9: optional bool use_fixed_length_serialization_opt
10: optional TSortInfo agg_sort_info_by_group_key
}

struct TRepeatNode {
Expand Down
Loading

0 comments on commit 8c1ead2

Please sign in to comment.