Skip to content

Commit

Permalink
[refactor](nereids)make NormalizeAggregate rule more clear and readable
Browse files Browse the repository at this point in the history
  • Loading branch information
starocean999 committed Dec 19, 2023
1 parent 8c58bb6 commit b886ece
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand All @@ -40,11 +41,9 @@
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -100,22 +99,74 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali
@Override
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// The LogicalAggregate node may contain window agg functions and usual agg functions.
// We refer to window agg functions as window-agg and usual agg functions as trivial-agg for short.
// This rule simplifies the LogicalAggregate node by:
// 1. Pushing down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
// 2. Creating a new LogicalAggregate with normalized group by exprs and trivial-aggs
// 3. Pulling up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
// Push down exprs:
// 1. all group by exprs
// 2. subquery expr in trivial-agg
// 3. all input slots of trivial-agg
// 4. expr(including subquery) in distinct trivial-agg
// Normalize LogicalAggregate's output:
// 1. normalize group by exprs by outputs of bottom LogicalProject
// 2. normalize trivial-aggs by outputs of bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in old LogicalAggregate to build a parent project node, typically including:
// 1. simple slots
// 2. aliases
// a. alias with no aggs child
// b. alias with trivial-agg child
// c. alias with window-agg

List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());

// collect all trivial-aggs
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));

// we need push down subquery exprs inside non-window and non-distinct agg functions
// collect subquery exprs in non-distinct trivial-aggs
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());

// collect all input slots of trivial-aggs
Set<Slot> allAggFuncInputSlots = aggFuncs.stream()
.flatMap(agg -> agg.getInputSlots().stream()).collect(Collectors.toSet());

// collect the child exprs of distinct trivial-aggs that have to be pushed down.
// A child that is already a plain slot reference or a literal is skipped.
// The two instanceof checks must be joined with ||: no expression is both a
// SlotReference and a Literal, so joining them with && never filters anything out.
Set<Expression> distinctAggChildExprs = aggFuncs.stream()
        .filter(agg -> agg.isDistinct()).flatMap(agg -> agg.children().stream())
        .filter(child -> !(child instanceof SlotReference || child instanceof Literal))
        .collect(Collectors.toSet());

Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);

// push down group by and subquery exprs to bottom project
Set<Expression> allPushDownExprs =
Sets.union(Sets.union(groupingByExprs, subqueryExprs), distinctAggChildExprs);
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
Set<NamedExpression> bottomProjects =
bottomSlotContext.pushDownToNamedExpression(allPushDownExprs);

// create bottom project
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(
ImmutableList.copyOf(Sets.union(bottomProjects, allAggFuncInputSlots)),
aggregate.child());
} else {
bottomPlan = aggregate.child();
}

// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
Expand All @@ -127,89 +178,37 @@ public Rule build() {
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
// TODO: if we have distinct agg, we must push down its children,
// because need use it to generate distribution enforce
// step 1: split agg functions into 2 parts: distinct and not distinct
List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
if (aggregateFunction.isDistinct()) {
distinctAggFuncs.add(aggregateFunction);
} else {
nonDistinctAggFuncs.add(aggregateFunction);
}
}
// step 2: if we only have one distinct agg function, we do push down for it
if (!distinctAggFuncs.isEmpty()) {
// process distinct normalize and put it back to normalizedAggFuncs
List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
Map<Expression, Expression> replaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : distinctAggFunc.children()) {
if (child instanceof SlotReference || child instanceof Literal) {
newChildren.add(child);
} else {
NamedExpression alias;
if (aliasCache.containsKey(child)) {
alias = aliasCache.get(child);
} else {
alias = new Alias(child);
aliasCache.put(child, alias);
}
bottomProjects.add(alias);
newChildren.add(alias.toSlot());
}
}
AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
replaceMap.put(distinctAggFunc, newDistinctAggFunc);
newDistinctAggFuncs.add(newDistinctAggFunc);
}
aggregateOutput = aggregateOutput.stream()
.map(e -> ExpressionUtils.replace(e, replaceMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
distinctAggFuncs = newDistinctAggFuncs;
}
normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
normalizedAggFuncs.addAll(distinctAggFuncs);
// TODO: process redundant expressions in aggregate functions children

// normalize group by exprs by bottomProjects
List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

// normalize trivial-aggs by bottomProjects
List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);

// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output include 2 part, normalized group by slots and normalized agg functions

// agg output includes 2 parts:
// all bottom projects (group by exprs are included in bottom projects) and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.addAll(bottomProjects.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext
.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Set<Slot> bottomProjectSlots = bottomProjects.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toSet());
Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(e -> !bottomProjectSlots.contains(e))
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
} else {
bottomPlan = aggregate.child();
}
// create new agg node
LogicalAggregate newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);

// create upper projects by normalize all output exprs in old LogicalAggregate
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);

return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
// create a parent project node
return new LogicalProject<>(upperProjects, newAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

// Regression suite: runs a count(distinct case when ... end) aggregate whose
// group by key is an expression (substr(t.k1,1,8)) under the Nereids planner.
// The query result is not compared against anything here; the statements are
// only executed.
// NOTE(review): presumably this guards the distinct-agg handling of the
// NormalizeAggregate rule changed in the same commit -- confirm.
suite("agg_distinct_case_when") {
// Use the Nereids planner and disable falling back to the original planner.
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
// Start from a clean state in case a previous run left the table behind.
sql "DROP TABLE IF EXISTS agg_test_table_t;"
// Four nullable string columns; k1 is both the duplicate key and the hash distribution column.
sql """
CREATE TABLE `agg_test_table_t` (
`k1` varchar(65533) NULL,
`k2` text NULL,
`k3` text null,
`k4` text null
) ENGINE=OLAP
DUPLICATE KEY(`k1`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`k1`) BUCKETS 10
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"is_being_synced" = "false",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false",
"enable_single_replica_compaction" = "false"
);
"""

// Insert a single row; k4 is not listed in the insert column list.
sql """insert into agg_test_table_t(`k1`,`k2`,`k3`) values('20231026221524','PA','adigu1bububud');"""
// Query under test: the distinct aggregate's argument is a CASE WHEN expression
// that references loan_date, an alias of substr(k1,1,8) defined in the subquery,
// while the outer query groups by the equivalent expression substr(t.k1,1,8).
sql """
select
count(distinct case when t.k2='PA' and loan_date=to_date(substr(t.k1,1,8)) then t.k2 end )
from (
select substr(k1,1,8) loan_date,k3,k2,k1 from agg_test_table_t) t
group by
substr(t.k1,1,8);"""

// Clean up the table created by this suite.
sql "DROP TABLE IF EXISTS agg_test_table_t;"
}

0 comments on commit b886ece

Please sign in to comment.