diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountOnIndex.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountOnIndex.java index 284011b2c27dfcb..8ccc8001896f8f4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountOnIndex.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownCountOnIndex.java @@ -23,8 +23,6 @@ import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -34,7 +32,6 @@ import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Rewriter of pushing down count on index. @@ -47,37 +44,22 @@ public Rule build() { logicalProject( logicalFilter( logicalOlapScan().whenNot(LogicalOlapScan::isCountOnIndexPushedDown) - ) + ).when(filter -> containsMatchExpression(filter.getExpressions()) + && filter.getExpressions().size() == 1) )) .when(agg -> enablePushDownCountOnIndex()) - .thenApply(ctx -> { - LogicalAggregate>> aggregate = ctx.root; - final LogicalAggregate canNotPush = aggregate; - - LogicalProject> project = aggregate.child(); + .when(agg -> agg.getGroupByExpressions().size() == 0) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + return !funcs.isEmpty() && funcs.stream() + .allMatch(f -> f instanceof Count && !f.isDistinct()); + }) + .then(agg -> { + LogicalProject> project = agg.child(); LogicalFilter filter = project.child(); LogicalOlapScan olapScan = filter.child(); - Set aggregateFunctions = aggregate.getAggregateFunctions(); - Set> functionClasses = aggregateFunctions - .stream() - .map(AggregateFunction::getClass) - .collect(Collectors.toSet()); - //if and only if there is a count on agg - if (!(functionClasses.size() == 1 && functionClasses.contains(Count.class))) { - return canNotPush; - } - if (aggregateFunctions.stream().anyMatch(fun -> fun.arity() > 1)) { - return canNotPush; - } - List expressions = filter.getExpressions(); - if (expressions.size() > 1) { - return canNotPush; - } - if (!containsMatchExpression(expressions)) { - return canNotPush; - } - return aggregate.withChildren(ImmutableList.of( + return agg.withChildren(ImmutableList.of( project.withChildren( ImmutableList.of( filter.withChildren(olapScan.withPushDownCountOnIndex(true))