diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 2361c2763727bfd..c5977be0b4fbefc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; +import org.apache.doris.nereids.rules.rewrite.AddProjectToAggregateChild; import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType; import org.apache.doris.nereids.rules.rewrite.AdjustNullable; import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction; @@ -155,6 +156,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // but we need to do some normalization before subquery unnesting, // such as extract common expression. new ExpressionNormalizationAndOptimization(), + new AddProjectToAggregateChild(), new AvgDistinctToSumDivCount(), new CountDistinctRewrite(), new ExtractFilterFromCrossJoin() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 696463523f6904e..29e6b08ff9283b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -77,6 +77,7 @@ public enum RuleType { ADJUST_NULLABLE_FOR_AGGREGATE_SLOT(RuleTypeClass.REWRITE), ADJUST_NULLABLE_FOR_HAVING_SLOT(RuleTypeClass.REWRITE), ADJUST_NULLABLE_FOR_REPEAT_SLOT(RuleTypeClass.REWRITE), + ADD_AGG_PROJECT(RuleTypeClass.REWRITE), ADD_DEFAULT_LIMIT(RuleTypeClass.REWRITE), CHECK_ROW_POLICY(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectToAggregateChild.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectToAggregateChild.java new file mode 100644 index 000000000000000..8b495bf1622e2a5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AddProjectToAggregateChild.java @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Length; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; + +import java.util.List; + +/** + * Rewrite aggregate function with length function to a seperated projection, which can accelerate when using pipeline + * engine. Pipeline engine can not accelerate length function in aggregate function but can accelerate in projection. + *

+ * agg(length()) ==> agg(alias)->project(length()) + */ +public class AddProjectToAggregateChild extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate().when(AddProjectToAggregateChild::containsLengthString).then(agg -> { + List outputExpressions = agg.getOutputExpressions(); + Builder newOutputs + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + Builder projectOutputs = ImmutableList.builder(); + for (NamedExpression outputExpression : outputExpressions) { + projectOutputs.add(outputExpression); + NamedExpression newOutput = (NamedExpression) outputExpression.rewriteUp(expr -> { + if (expr instanceof Length) { + NamedExpression alias = new Alias(expr, ((BoundFunction) expr).getName()); + projectOutputs.add(alias); + return alias.toSlot(); + } + return expr; + }); + newOutputs.add(newOutput); + } + LogicalProject project = new LogicalProject<>(projectOutputs.build(), agg.child()); + return agg.withAggOutputChild(newOutputs.build(), project); + }).toRule(RuleType.ADD_AGG_PROJECT); + } + + private static boolean containsLengthString(LogicalAggregate agg) { + for (NamedExpression ne : agg.getOutputExpressions()) { + boolean needRewrite = ne.anyMatch(expr -> { + if (expr instanceof AggregateFunction && expr.containsType(Length.class) + && !expr.containsType(Alias.class)) { + return true; + } + return false; + }); + if (needRewrite) { + return true; + } + } + return false; + } +} diff --git a/regression-test/suites/nereids_p0/aggregate/agg_length_function.groovy b/regression-test/suites/nereids_p0/aggregate/agg_length_function.groovy new file mode 100644 index 000000000000000..cbdb09d1e9e960c --- /dev/null +++ b/regression-test/suites/nereids_p0/aggregate/agg_length_function.groovy @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("agg_length_function") { + sql 'set enable_nereids_planner=true' + sql 'set enable_fallback_to_original_planner=false' + + sql 'drop table if exists test_agg_length_function;' + + sql '''create table test_agg_length_function(k0 int, k1 string, k2 varchar, k3 char(5)) distributed by hash(k0) buckets 3 properties('replication_num' = '1');''' + + // for string type + def res = sql ''' + explain select AVG(length(k1)) from test_agg_length_function; + ''' + assertTrue(res.toString().contains("length[#")) + +}