Skip to content

Commit

Permalink
[fix](nereids) build agg for random distributed agg table in bindRela…
Browse files Browse the repository at this point in the history
…tion phase (#40181) (#40702)

pick from master #40181
  • Loading branch information
starocean999 authored Sep 12, 2024
1 parent e2dc754 commit 0f8176d
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 309 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.rules.analysis.BindSink;
import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
Expand Down Expand Up @@ -163,8 +162,6 @@ private static List<RewriteJob> buildAnalyzerJobs(Optional<CustomTableResolver>
topDown(new EliminateGroupByConstant()),

topDown(new SimplifyAggGroupBy()),
// run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan
topDown(new BuildAggForRandomDistributedTable()),
topDown(new NormalizeAggregate()),
topDown(new HavingToFilter()),
bottomUp(new SemiJoinCommute()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,6 @@ public enum RuleType {

// topn opts
DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
// pre agg for random distributed table
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
// short circuit rule
SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE),
// exploration rules
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.catalog.AggStateType;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.DistributionInfo;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Partition;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.catalog.Type;
import org.apache.doris.catalog.View;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
Expand All @@ -44,13 +51,26 @@
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.Relation;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
Expand All @@ -74,6 +94,7 @@
import com.google.common.collect.Sets;
import org.apache.commons.collections.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
Expand Down Expand Up @@ -215,25 +236,127 @@ private LogicalPlan makeOlapScan(TableIf table, UnboundRelation unboundRelation,
unboundRelation.getTableSample());
}
}
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
// table qualifier is catalog.db.table, we make db.table.column
Slot deleteSlot = null;
for (Slot slot : scan.getOutput()) {
if (slot.getName().equals(Column.DELETE_SIGN)) {
deleteSlot = slot;
if (needGenerateLogicalAggForRandomDistAggTable(scan)) {
// it's a random distribution agg table
// add agg on olap scan
return preAggForRandomDistribution(scan);
} else {
// it's a duplicate, unique or hash distribution agg table
// add delete sign filter on olap scan if needed
if (!Util.showHiddenColumns() && scan.getTable().hasDeleteSign()
&& !ConnectContext.get().getSessionVariable().skipDeleteSign()) {
// table qualifier is catalog.db.table, we make db.table.column
Slot deleteSlot = null;
for (Slot slot : scan.getOutput()) {
if (slot.getName().equals(Column.DELETE_SIGN)) {
deleteSlot = slot;
break;
}
}
Preconditions.checkArgument(deleteSlot != null);
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
scan = scan.withPreAggStatus(
PreAggStatus.off(Column.DELETE_SIGN + " is used as conjuncts."));
}
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
}
return scan;
}
}

/**
 * Tells whether a LogicalAggregate has to be generated on top of the given scan.
 *
 * <p>The answer is {@code true} only when the current connection is executing a query
 * and the scanned table is an AGG_KEYS table whose default distribution is RANDOM.
 *
 * @param olapScan the olap scan to inspect
 * @return true if an aggregate should be added above the scan, false otherwise
 */
private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) {
    ConnectContext context = ConnectContext.get();
    // we only need to add an agg node for query, and should not do it for deleting
    // from random distributed table. see https://github.com/apache/doris/pull/37985 for more info
    if (context == null || context.getState() == null || !context.getState().isQuery()) {
        return false;
    }
    OlapTable scannedTable = olapScan.getTable();
    KeysType tableKeysType = scannedTable.getKeysType();
    DistributionInfo defaultDistribution = scannedTable.getDefaultDistributionInfo();
    if (tableKeysType != KeysType.AGG_KEYS) {
        return false;
    }
    return defaultDistribution.getType() == DistributionInfo.DistributionInfoType.RANDOM;
}

/**
* add LogicalAggregate above olapScan for preAgg.
* Key columns of the table's base schema become both group-by and output expressions;
* every other column is wrapped in the aggregate function returned by generateAggFunction
* and aliased to the column name, reusing the slot's ExprId.
* If generateAggFunction returns null for any non-key column, the scan is returned unchanged.
* @param olapScan olap scan plan
* @return rewritten plan, or olapScan itself when no aggregate can be built
*/
private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) {
OlapTable olapTable = olapScan.getTable();
List<Slot> childOutputSlots = olapScan.computeOutput();
List<Expression> groupByExpressions = new ArrayList<>();
List<NamedExpression> outputExpressions = new ArrayList<>();
List<Column> columns = olapTable.getBaseSchema();

for (Column col : columns) {
// use the existing slot in the plan: build a slot from the column, then take over the
// ExprId of the scan output slot with the same name, if there is one
SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified());
ExprId exprId = slot.getExprId();
for (Slot childSlot : childOutputSlots) {
// NOTE(review): the two names are compared with == (reference identity), not equals();
// confirm this is intended, otherwise equal names held in distinct String objects will not match
if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) {
exprId = childSlot.getExprId();
slot = slot.withExprId(exprId);
break;
}
}
// NOTE(review): the next five lines reference deleteSlot, table and scan, none of which is
// declared in this method; they look like removed-side diff lines of the old delete-sign
// filter logic in makeOlapScan that got interleaved here. TODO confirm against the real file
Preconditions.checkArgument(deleteSlot != null);
Expression conjunct = new EqualTo(new TinyIntLiteral((byte) 0), deleteSlot);
if (!((OlapTable) table).getEnableUniqueKeyMergeOnWrite()) {
scan = scan.withPreAggStatus(PreAggStatus.off(
Column.DELETE_SIGN + " is used as conjuncts."));
if (col.isKey()) {
// key column: group by it and output it as is
groupByExpressions.add(slot);
outputExpressions.add(slot);
} else {
// value column: aggregate it according to the column's aggregation type
Expression function = generateAggFunction(slot, col);
// DO NOT rewrite: no aggregate function is available for this column
if (function == null) {
return olapScan;
}
Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(),
olapScan.qualified(), true);
outputExpressions.add(alias);
}
// NOTE(review): the next three lines (LogicalFilter return, closing brace, return scan) also
// appear to be removed-side diff residue from makeOlapScan. TODO confirm
return new LogicalFilter<>(Sets.newHashSet(conjunct), scan);
}
return scan;
LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions,
olapScan);
return aggregate;
}

/**
 * Build the aggregate function matching the aggregation type of a column.
 *
 * <p>SUM, MAX, MIN, HLL_UNION, BITMAP_UNION and QUANTILE_UNION map to the corresponding
 * aggregate function over the slot. GENERIC is supported only when the column type is an
 * agg_state type, in which case the "union" combinator of the underlying function is used.
 *
 * @param slot slot of column, used as the single argument of the generated function
 * @param column column whose aggregation type (and, for GENERIC, data type) drives the choice
 * @return the generated aggregate function, or null when the aggregation type is not supported
 */
private Expression generateAggFunction(SlotReference slot, Column column) {
    Expression aggFunction = null;
    switch (column.getAggregationType()) {
        case SUM:
            aggFunction = new Sum(slot);
            break;
        case MAX:
            aggFunction = new Max(slot);
            break;
        case MIN:
            aggFunction = new Min(slot);
            break;
        case HLL_UNION:
            aggFunction = new HllUnion(slot);
            break;
        case BITMAP_UNION:
            aggFunction = new BitmapUnion(slot);
            break;
        case QUANTILE_UNION:
            aggFunction = new QuantileUnion(slot);
            break;
        case GENERIC: {
            Type columnType = column.getType();
            if (columnType.isAggStateType()) {
                // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one
                String unionName = ((AggStateType) columnType).getFunctionName()
                        + AggCombinerFunctionBuilder.UNION_SUFFIX;
                FunctionBuilder unionBuilder = Env.getCurrentEnv().getFunctionRegistry()
                        .findFunctionBuilder(unionName, slot);
                aggFunction = unionBuilder.build(unionName, ImmutableList.of(slot)).first;
            }
            break;
        }
        default:
            break;
    }
    return aggFunction;
}

private LogicalPlan getLogicalPlan(TableIf table, UnboundRelation unboundRelation,
Expand Down
Loading

0 comments on commit 0f8176d

Please sign in to comment.