From b886ece1651d7f824dad304392a431dfd0b65baa Mon Sep 17 00:00:00 2001
From: starocean999 <12095047@qq.com>
Date: Tue, 19 Dec 2023 09:59:22 +0800
Subject: [PATCH] [refactor](nereids)make NormalizeAggregate rule more clear
 and readable

---
 .../rules/analysis/NormalizeAggregate.java   | 163 +++++++++---------
 .../aggregate/agg_distinct_case_when.groovy  |  54 ++++++
 2 files changed, 135 insertions(+), 82 deletions(-)
 create mode 100644 regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index d265c3d8d408303..577d4c6fbf53ba9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -20,6 +20,7 @@
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
+import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
 import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
@@ -40,11 +41,9 @@
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -100,22 +99,74 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
     @Override
     public Rule build() {
         return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
+            // The LogicalAggregate node may contain both window agg functions and ordinary agg functions;
+            // for short, we call window agg functions window-agg and ordinary agg functions trivial-agg.
+            // This rule simplifies the LogicalAggregate node by:
+            // 1. pushing down some exprs from the old LogicalAggregate node to a new child LogicalProject node,
+            // 2. creating a new LogicalAggregate with normalized group by exprs and trivial-aggs,
+            // 3. pulling up the old LogicalAggregate's normalized output exprs to a new parent LogicalProject node.
+            // Push down exprs:
+            // 1. all group by exprs
+            // 2. subquery exprs in trivial-agg
+            // 3. all input slots of trivial-agg
+            // 4. exprs (including subqueries) in distinct trivial-agg
+            // Normalize LogicalAggregate's output:
+            // 1. normalize group by exprs by the outputs of the bottom LogicalProject
+            // 2. normalize trivial-aggs by the outputs of the bottom LogicalProject
+            // 3. build the normalized agg outputs
+            // Pull up exprs:
+            // normalize all output exprs of the old LogicalAggregate to build a parent project node;
+            // the outputs typically include:
+            // 1. simple slots
+            // 2. aliases
+            //    a. alias with no agg child
+            //    b. alias with trivial-agg child
+            //    c. alias with window-agg
-            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
-            Set<Alias> existsAlias = ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+            // Push down exprs:
+            // collect group by exprs
+            Set<Expression> groupingByExprs =
+                    ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+            // collect all trivial-aggs
+            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
             List<AggregateFunction> aggFuncs = Lists.newArrayList();
             aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
 
-            // we need push down subquery exprs inside non-window and non-distinct agg functions
+            // collect subquery exprs in trivial-agg
             Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(aggFuncs.stream()
                     .filter(aggFunc -> !aggFunc.isDistinct()).collect(Collectors.toList()),
                     SubqueryExpr.class::isInstance);
-            Set<Expression> groupingByExprs = ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+
+            // collect all input slots of trivial-agg
+            Set<Slot> allAggFuncInputSlots = aggFuncs.stream()
+                    .flatMap(agg -> agg.getInputSlots().stream()).collect(Collectors.toSet());
+
+            // collect exprs in distinct trivial-agg
+            Set<Expression> distinctAggChildExprs = aggFuncs.stream()
+                    .filter(agg -> agg.isDistinct()).flatMap(agg -> agg.children().stream())
+                    .filter(child -> !(child instanceof SlotReference || child instanceof Literal))
+                    .collect(Collectors.toSet());
+
+            Set<Alias> existsAlias =
+                    ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+
+            // push down group by exprs, subquery exprs and distinct-agg child exprs to the bottom project
+            Set<Expression> allPushDownExprs =
+                    Sets.union(Sets.union(groupingByExprs, subqueryExprs), distinctAggChildExprs);
             NormalizeToSlotContext bottomSlotContext =
-                    NormalizeToSlotContext.buildContext(existsAlias, Sets.union(groupingByExprs, subqueryExprs));
-            Set<NamedExpression> bottomOutputs =
-                    bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs, subqueryExprs));
+                    NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
+            Set<NamedExpression> bottomProjects =
+                    bottomSlotContext.pushDownToNamedExpression(allPushDownExprs);
+
+            // create bottom project
+            Plan bottomPlan;
+            if (!bottomProjects.isEmpty()) {
+                bottomPlan = new LogicalProject<>(
+                        ImmutableList.copyOf(Sets.union(bottomProjects, allAggFuncInputSlots)),
+                        aggregate.child());
+            } else {
+                bottomPlan = aggregate.child();
+            }
 
             // use group by context to normalize agg functions to process
             //   sql like: select sum(a + 1) from t group by a + 1
@@ -127,89 +178,37 @@ public Rule build() {
             //   after normalize:
             //   agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
             //   +-- project((a[#0] + 1)[#1])
-            List<AggregateFunction> normalizedAggFuncs = bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
-            Set<NamedExpression> bottomProjects = Sets.newHashSet(bottomOutputs);
-            // TODO: if we have distinct agg, we must push down its children,
-            //   because need use it to generate distribution enforce
-            // step 1: split agg functions into 2 parts: distinct and not distinct
-            List<AggregateFunction> distinctAggFuncs = Lists.newArrayList();
-            List<AggregateFunction> nonDistinctAggFuncs = Lists.newArrayList();
-            for (AggregateFunction aggregateFunction : normalizedAggFuncs) {
-                if (aggregateFunction.isDistinct()) {
-                    distinctAggFuncs.add(aggregateFunction);
-                } else {
-                    nonDistinctAggFuncs.add(aggregateFunction);
-                }
-            }
-            // step 2: if we only have one distinct agg function, we do push down for it
-            if (!distinctAggFuncs.isEmpty()) {
-                // process distinct normalize and put it back to normalizedAggFuncs
-                List<AggregateFunction> newDistinctAggFuncs = Lists.newArrayList();
-                Map<Expression, Expression> replaceMap = Maps.newHashMap();
-                Map<Expression, NamedExpression> aliasCache = Maps.newHashMap();
-                for (AggregateFunction distinctAggFunc : distinctAggFuncs) {
-                    List<Expression> newChildren = Lists.newArrayList();
-                    for (Expression child : distinctAggFunc.children()) {
-                        if (child instanceof SlotReference || child instanceof Literal) {
-                            newChildren.add(child);
-                        } else {
-                            NamedExpression alias;
-                            if (aliasCache.containsKey(child)) {
-                                alias = aliasCache.get(child);
-                            } else {
-                                alias = new Alias(child);
-                                aliasCache.put(child, alias);
-                            }
-                            bottomProjects.add(alias);
-                            newChildren.add(alias.toSlot());
-                        }
-                    }
-                    AggregateFunction newDistinctAggFunc = distinctAggFunc.withChildren(newChildren);
-                    replaceMap.put(distinctAggFunc, newDistinctAggFunc);
-                    newDistinctAggFuncs.add(newDistinctAggFunc);
-                }
-                aggregateOutput = aggregateOutput.stream()
-                        .map(e -> ExpressionUtils.replace(e, replaceMap))
-                        .map(NamedExpression.class::cast)
-                        .collect(Collectors.toList());
-                distinctAggFuncs = newDistinctAggFuncs;
-            }
-            normalizedAggFuncs = Lists.newArrayList(nonDistinctAggFuncs);
-            normalizedAggFuncs.addAll(distinctAggFuncs);
-            // TODO: process redundant expressions in aggregate functions children
+
+            // normalize group by exprs by bottomProjects
+            List<Expression> normalizedGroupExprs =
+                    bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
+
+            // normalize trivial-aggs by bottomProjects
+            List<AggregateFunction> normalizedAggFuncs =
+                    bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
+
             // build normalized agg output
             NormalizeToSlotContext normalizedAggFuncsToSlotContext =
                     NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
-            // agg output include 2 part, normalized group by slots and normalized agg functions
+
+            // the agg output includes 2 parts:
+            // all bottom projects (group by exprs are included in the bottom projects) and the normalized agg functions
             List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
-                    .addAll(bottomOutputs.stream().map(NamedExpression::toSlot).iterator())
-                    .addAll(normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs))
+                    .addAll(bottomProjects.stream().map(NamedExpression::toSlot).iterator())
+                    .addAll(normalizedAggFuncsToSlotContext
+                            .pushDownToNamedExpression(normalizedAggFuncs))
                     .build();
-            // add normalized agg's input slots to bottom projects
-            Set<Slot> bottomProjectSlots = bottomProjects.stream()
-                    .map(NamedExpression::toSlot)
-                    .collect(Collectors.toSet());
-            Set<Slot> aggInputSlots = normalizedAggFuncs.stream()
-                    .map(Expression::getInputSlots)
-                    .flatMap(Set::stream)
-                    .filter(e -> !bottomProjectSlots.contains(e))
-                    .collect(Collectors.toSet());
-            bottomProjects.addAll(aggInputSlots);
-            // build group by exprs
-            List<Expression> normalizedGroupExprs = bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
 
-            Plan bottomPlan;
-            if (!bottomProjects.isEmpty()) {
-                bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
-            } else {
-                bottomPlan = aggregate.child();
-            }
+            // create the new agg node
+            LogicalAggregate<Plan> newAggregate =
+                    aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
 
+            // create upper projects by normalizing all output exprs of the old LogicalAggregate
             List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
                     bottomSlotContext, normalizedAggFuncsToSlotContext);
 
-            return new LogicalProject<>(upperProjects,
-                    aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan));
+            // create a parent project node
+            return new LogicalProject<>(upperProjects, newAggregate);
         }).toRule(RuleType.NORMALIZE_AGGREGATE);
     }
 
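Illustration (editor's note, not part of the patch): the push-down sets introduced above are easiest to see on a small query in the style of the comment inside the rule. Table and column names here are invented and the slot ids are arbitrary. For

    select count(distinct b + 1) from t group by a;

the new code collects groupingByExprs = {a}, subqueryExprs = {}, distinctAggChildExprs = {b + 1} and allAggFuncInputSlots = {b}, so the normalized shape is roughly:

    project(count(distinct (b + 1)[#2])[#3])                       -- upper project pulls up the old output
    +-- agg(output: [..., count(distinct (b + 1)[#2])[#3]], group_by: a[#0])
        +-- project(a[#0], (b + 1)[#2], b[#1])                     -- bottom project: group by expr, distinct child, agg input slots
            +-- scan(t)

The distinct child (b + 1) is materialized in the bottom project so that, as the removed TODO comment noted, the distribution enforcement for the distinct aggregate can later be generated on a plain slot rather than on a compound expression.
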
diff --git a/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
new file mode 100644
index 000000000000000..546586702e5f4b0
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/agg_distinct_case_when.groovy
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_distinct_case_when") {
+    sql "SET enable_nereids_planner=true"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql "DROP TABLE IF EXISTS agg_test_table_t;"
+    sql """
+        CREATE TABLE `agg_test_table_t` (
+            `k1` varchar(65533) NULL,
+            `k2` text NULL,
+            `k3` text null,
+            `k4` text null
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`k1`)
+        COMMENT 'OLAP'
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 10
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "is_being_synced" = "false",
+            "storage_format" = "V2",
+            "light_schema_change" = "true",
+            "disable_auto_compaction" = "false",
+            "enable_single_replica_compaction" = "false"
+        );
+    """
+
+    sql """insert into agg_test_table_t(`k1`,`k2`,`k3`) values('20231026221524','PA','adigu1bububud');"""
+    sql """
+        select
+            count(distinct case when t.k2='PA' and loan_date=to_date(substr(t.k1,1,8)) then t.k2 end)
+        from (
+            select substr(k1,1,8) loan_date, k3, k2, k1 from agg_test_table_t) t
+        group by
+            substr(t.k1,1,8);"""
+
+    sql "DROP TABLE IF EXISTS agg_test_table_t;"
+}
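Illustration (editor's note, not part of the patch): the regression test exercises exactly the distinct push-down path described above. The argument of count(distinct ...) is a case-when over columns of a derived table, so it is neither a slot nor a literal and has to be pushed into the bottom project; the group by key substr(t.k1,1,8) is a pushed-down expression as well. Under the refactored rule the test query should normalize to roughly the following shape (same sketch notation as the Java comments, slot ids and alias names invented):

    project(count(distinct case_when[#3])[#4])
    +-- agg(output: [..., count(distinct case_when[#3])[#4]], group_by: substr(k1,1,8)[#2])
        +-- project(substr(k1,1,8)[#2], (case when k2='PA' and loan_date=to_date(substr(k1,1,8)) then k2 end) as case_when[#3], k1[#0], k2[#1])
            +-- subquery t (select substr(k1,1,8) loan_date, k3, k2, k1 from agg_test_table_t)

The test only executes the query under the Nereids planner with fallback disabled; it does not assert on the result, so it serves as a plan-and-run smoke test for this shape.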