
Commit

[refactor](nereids) make NormalizeAggregate rule more clear and readable (apache#28607)
starocean999 authored and stephen committed Dec 28, 2023
1 parent 668321c commit 7b8da5b
Showing 13 changed files with 438 additions and 374 deletions.
@@ -20,6 +20,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
@@ -40,9 +41,9 @@
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -100,22 +101,94 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
@Override
public Rule build() {
return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// The LogicalAggregate node may contain both window agg functions and usual agg functions;
// we call window agg functions window-agg and usual agg functions trivial-agg for short.
// This rule simplifies the LogicalAggregate node by:
// 1. pushing down some exprs from the old LogicalAggregate node to a new child LogicalProject node,
// 2. creating a new LogicalAggregate with normalized group by exprs and trivial-aggs,
// 3. pulling up the old LogicalAggregate's normalized output exprs to a new parent LogicalProject node.
// Push down exprs:
// 1. all group by exprs
// 2. children of trivial-aggs that contain a subquery expr
// 3. children of trivial-aggs that contain a window expr
// 4. all input slots of trivial-aggs
// 5. exprs (including subqueries) in distinct trivial-aggs
// Normalize LogicalAggregate's output:
// 1. normalize group by exprs by the outputs of the bottom LogicalProject
// 2. normalize trivial-aggs by the outputs of the bottom LogicalProject
// 3. build normalized agg outputs
// Pull up exprs:
// normalize all output exprs in the old LogicalAggregate to build a parent project node, typically including:
// 1. simple slots
// 2. aliases
//    a. aliases whose child contains no agg
//    b. aliases whose child is a trivial-agg
//    c. aliases whose child is a window-agg
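// Editor's illustration (not part of this commit), assuming a query like
//   select a, sum(b + 1), count(distinct c + 1) from t group by a:
// the bottom project would output a, b, and an alias y for (c + 1); the new
// aggregate would compute sum(b + 1) and count(distinct y) grouped by a; and
// the parent project would restore the original output names.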

List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
// Push down exprs:
// collect group by exprs
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());

// collect all trivial-aggs
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));

// we need to push down subquery exprs inside non-window and non-distinct agg functions
Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
// split the children of non-distinct aggs into two parts
// TRUE part 1: push down the child itself, if it contains a subquery or window expression
// FALSE part 2: push down its input slots, if it DOES NOT contain a subquery or window expression
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> !aggFunc.isDistinct())
.flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
Collectors.toSet()));

// split the children of distinct aggs into two parts
// TRUE part 1: push down the child itself, if it is NOT a SlotReference or Literal
// FALSE part 2: push down its input slots, if it is a SlotReference or Literal
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
.filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
.collect(Collectors.groupingBy(
child -> !(child instanceof SlotReference || child instanceof Literal),
Collectors.toSet()));

Set<Expression> needPushSelf = Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
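// Editor's note (illustration, not part of this commit): with the sets above,
// count(distinct c + 1) lands in the TRUE partition, so c + 1 is pushed down
// as a whole, while sum(b + 1) lands in the FALSE partition, so only its
// input slot b is pushed down.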

Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);

// push down 3 kinds of exprs; these pushed exprs will be used to normalize the agg output later
// 1. group by exprs
// 2. trivial-agg children
// 3. trivial-agg input slots
Set<Expression> allPushDownExprs =
Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
NormalizeToSlotContext bottomSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
Set<NamedExpression> pushedGroupByExprs =
bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
Set<NamedExpression> pushedTrivalAggChildren =
bottomSlotContext.pushDownToNamedExpression(needPushSelf);
Set<NamedExpression> pushedTrivalAggInputSlots =
bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));

// create bottom project
Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
aggregate.child());
} else {
bottomPlan = aggregate.child();
}
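// Editor's note (illustration, not part of this commit): if nothing needs
// pushing down, e.g. select count(*) from t, bottomProjects is empty and the
// aggregate keeps its original child instead of gaining an empty project.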

// use the group by context to normalize agg functions, to process
// sql like: select sum(a + 1) from t group by a + 1
@@ -127,89 +200,37 @@ public Rule build() {
// after normalize:
// agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
// +-- project((a[#0] + 1)[#1])
List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
// TODO: if we have distinct aggs, we must push down their children,
// because we need them to generate the distribution enforcer
// step 1: split agg functions into 2 parts: distinct and not distinct
List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
if (aggregateFunction.isDistinct()) {
distinctAggFuncs.add(aggregateFunction);
} else {
nonDistinctAggFuncs.add(aggregateFunction);
}
}
// step 2: if we have distinct agg functions, push down their children
if (!distinctAggFuncs.isEmpty()) {
// normalize the distinct aggs and put them back into normalizedAggFuncs
List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
Map<Expression, Expression> replaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : distinctAggFunc.children()) {
if (child instanceof SlotReference || child instanceof Literal) {
newChildren.add(child);
} else {
NamedExpression alias;
if (aliasCache.containsKey(child)) {
alias = aliasCache.get(child);
} else {
alias = new Alias(child);
aliasCache.put(child, alias);
}
bottomProjects.add(alias);
newChildren.add(alias.toSlot());
}
}
AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
replaceMap.put(distinctAggFunc, newDistinctAggFunc);
newDistinctAggFuncs.add(newDistinctAggFunc);
}
aggregateOutput = aggregateOutput.stream()
.map(e -> ExpressionUtils.replace(e, replaceMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
distinctAggFuncs = newDistinctAggFuncs;
}
normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
normalizedAggFuncs.addAll(distinctAggFuncs);
// TODO: process redundant expressions in aggregate functions' children

// normalize group by exprs by bottomProjects
List<Expression> normalizedGroupExprs =
bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

// normalize trivial-aggs by bottomProjects
List<AggregateFunction> normalizedAggFuncs =
bottomSlotContext.normalizeToUseSlotRef(aggFuncs);

// build normalized agg output
NormalizeToSlotContext normalizedAggFuncsToSlotContext =
NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
// agg output includes 2 parts: normalized group by slots and normalized agg functions

// agg output includes 2 parts:
// pushedGroupByExprs and normalized agg functions
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
.addAll(normalizedAggFuncsToSlotContext
.pushDownToNamedExpression(normalizedAggFuncs))
.build();
// add normalized agg's input slots to bottom projects
Set<Slot> bottomProjectSlots = bottomProjects.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toSet());
Set<NamedExpression> aggInputSlots = normalizedAggFuncs.stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(e -> !bottomProjectSlots.contains(e))
.collect(Collectors.toSet());
bottomProjects.addAll(aggInputSlots);
// build group by exprs
List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);

Plan bottomPlan;
if (!bottomProjects.isEmpty()) {
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
} else {
bottomPlan = aggregate.child();
}
// create new agg node
LogicalAggregate newAggregate =
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);

// create upper projects by normalizing all output exprs in the old LogicalAggregate
List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
bottomSlotContext, normalizedAggFuncsToSlotContext);

return new LogicalProject<>(upperProjects,
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
// create a parent project node
return new LogicalProject<>(upperProjects, newAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
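The core idiom of this refactor is partitioning agg-function children by a predicate with Collectors.groupingBy. Below is a minimal, self-contained sketch of that idiom; the class name, stand-in string data, and predicate are illustrative only, not from the patch:

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class PartitionDemo {
    public static void main(String[] args) {
        // Stand-ins for agg-function children; the rule partitions Expressions.
        List<String> children = List.of("b + 1", "c + 1", "slot:a");

        // Collectors.groupingBy with a boolean classifier splits the stream into
        // at most two sets keyed true/false, mirroring the categorization above.
        Map<Boolean, Set<String>> categorized = children.stream()
                .collect(Collectors.groupingBy(
                        child -> child.startsWith("slot:"), // predicate varies per case
                        Collectors.toSet()));

        // A key is absent when no element matched, hence the getOrDefault
        // guards used in the rule.
        Set<String> pushSlots = categorized.getOrDefault(true, new HashSet<>());
        Set<String> pushSelf = categorized.getOrDefault(false, new HashSet<>());
        System.out.println("push input slots of: " + pushSlots);
        System.out.println("push whole expr: " + pushSelf);
    }
}

Because groupingBy drops empty groups rather than mapping them to empty sets, the getOrDefault guards keep the later Sets.union calls null-safe.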

@@ -3,37 +3,36 @@
PhysicalResultSink
--PhysicalTopN[MERGE_SORT]
----PhysicalTopN[LOCAL_SORT]
------PhysicalProject
--------hashAgg[DISTINCT_GLOBAL]
----------PhysicalDistribute
------------hashAgg[DISTINCT_LOCAL]
--------------hashAgg[GLOBAL]
----------------hashAgg[LOCAL]
------------------PhysicalProject
--------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((cs1.cs_order_number = cs2.cs_order_number)) otherCondition=(( not (cs_warehouse_sk = cs_warehouse_sk))) build RFs:RF4 cs_order_number->[cs_order_number]
------hashAgg[DISTINCT_GLOBAL]
--------PhysicalDistribute
----------hashAgg[DISTINCT_LOCAL]
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalProject
------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((cs1.cs_order_number = cs2.cs_order_number)) otherCondition=(( not (cs_warehouse_sk = cs_warehouse_sk))) build RFs:RF4 cs_order_number->[cs_order_number]
--------------------PhysicalDistribute
----------------------PhysicalProject
------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF4
--------------------hashJoin[RIGHT_ANTI_JOIN] hashCondition=((cs1.cs_order_number = cr1.cr_order_number)) otherCondition=() build RFs:RF3 cs_order_number->[cr_order_number]
----------------------PhysicalDistribute
------------------------PhysicalProject
--------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF4
----------------------hashJoin[RIGHT_ANTI_JOIN] hashCondition=((cs1.cs_order_number = cr1.cr_order_number)) otherCondition=() build RFs:RF3 cs_order_number->[cr_order_number]
------------------------PhysicalDistribute
--------------------------PhysicalProject
----------------------------PhysicalOlapScan[catalog_returns] apply RFs: RF3
------------------------PhysicalDistribute
--------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_call_center_sk = call_center.cc_call_center_sk)) otherCondition=() build RFs:RF2 cc_call_center_sk->[cs_call_center_sk]
----------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_ship_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_ship_date_sk]
------------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_ship_addr_sk = customer_address.ca_address_sk)) otherCondition=() build RFs:RF0 ca_address_sk->[cs_ship_addr_sk]
--------------------------------PhysicalProject
----------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1 RF2
--------------------------------PhysicalDistribute
----------------------------------PhysicalProject
------------------------------------filter((customer_address.ca_state = 'PA'))
--------------------------------------PhysicalOlapScan[customer_address]
--------------------------PhysicalOlapScan[catalog_returns] apply RFs: RF3
----------------------PhysicalDistribute
------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_call_center_sk = call_center.cc_call_center_sk)) otherCondition=() build RFs:RF2 cc_call_center_sk->[cs_call_center_sk]
--------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_ship_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[cs_ship_date_sk]
----------------------------hashJoin[INNER_JOIN] hashCondition=((cs1.cs_ship_addr_sk = customer_address.ca_address_sk)) otherCondition=() build RFs:RF0 ca_address_sk->[cs_ship_addr_sk]
------------------------------PhysicalProject
--------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF0 RF1 RF2
------------------------------PhysicalDistribute
--------------------------------PhysicalProject
----------------------------------filter((date_dim.d_date <= '2002-05-31') and (date_dim.d_date >= '2002-04-01'))
------------------------------------PhysicalOlapScan[date_dim]
----------------------------------filter((customer_address.ca_state = 'PA'))
------------------------------------PhysicalOlapScan[customer_address]
----------------------------PhysicalDistribute
------------------------------PhysicalProject
--------------------------------filter((call_center.cc_county = 'Williamson County'))
----------------------------------PhysicalOlapScan[call_center]
--------------------------------filter((date_dim.d_date <= '2002-05-31') and (date_dim.d_date >= '2002-04-01'))
----------------------------------PhysicalOlapScan[date_dim]
--------------------------PhysicalDistribute
----------------------------PhysicalProject
------------------------------filter((call_center.cc_county = 'Williamson County'))
--------------------------------PhysicalOlapScan[call_center]

(11 more changed files not shown)
