diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 7bbbc7841e8235..e9944487805bc2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; @@ -108,47 +109,72 @@ public List buildRules() { logicalAggregate( logicalFilter( logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable) - ).when(filter -> !filter.getConjuncts().isEmpty())) - .when(agg -> enablePushDownCountOnIndex()) - .when(agg -> agg.getGroupByExpressions().isEmpty()) - .when(agg -> { - Set funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isStar() - || f.children().isEmpty() - || (f.children().size() == 1 && f.child(0) instanceof Literal) - || f.child(0) instanceof Slot)); - }) - .thenApply(ctx -> { - LogicalAggregate> agg = ctx.root; - LogicalFilter filter = agg.child(); - LogicalOlapScan olapScan = filter.child(); - return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext); - }) + ) + ) + .when(agg -> enablePushDownCountOnIndex()) + .when(agg -> agg.getGroupByExpressions().isEmpty()) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + if (funcs.isEmpty() || !funcs.stream() + .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isStar() + || f.children().isEmpty() + || (f.children().size() == 1 && f.child(0) instanceof Literal) + || f.child(0) instanceof Slot))) { + return false; + } + Set conjuncts = agg.child().getConjuncts(); + if (conjuncts.isEmpty()) { + return false; + } + + Set aggSlots = funcs.stream() + .flatMap(f -> f.getInputSlots().stream()) + .collect(Collectors.toSet()); + return conjuncts.stream().allMatch(expr -> checkSlotInOrExpression(expr, aggSlots)); + }) + .thenApply(ctx -> { + LogicalAggregate> agg = ctx.root; + LogicalFilter filter = agg.child(); + LogicalOlapScan olapScan = filter.child(); + return pushdownCountOnIndex(agg, null, filter, olapScan, ctx.cascadesContext); + }) ), RuleType.COUNT_ON_INDEX.build( logicalAggregate( logicalProject( logicalFilter( logicalOlapScan().when(this::isDupOrMowKeyTable).when(this::isInvertedIndexEnabledOnTable) - ).when(filter -> !filter.getConjuncts().isEmpty()))) - .when(agg -> enablePushDownCountOnIndex()) - .when(agg -> agg.getGroupByExpressions().isEmpty()) - .when(agg -> { - Set funcs = agg.getAggregateFunctions(); - return !funcs.isEmpty() && funcs.stream() - .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isStar() - || f.children().isEmpty() - || (f.children().size() == 1 && f.child(0) instanceof Literal) - || f.child(0) instanceof Slot)); - }) - .thenApply(ctx -> { - LogicalAggregate>> agg = ctx.root; - LogicalProject> project = agg.child(); - LogicalFilter filter = project.child(); - LogicalOlapScan olapScan = filter.child(); - return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext); - }) + ) + ) + ) + .when(agg -> enablePushDownCountOnIndex()) + .when(agg -> agg.getGroupByExpressions().isEmpty()) + .when(agg -> { + Set funcs = agg.getAggregateFunctions(); + if (funcs.isEmpty() || !funcs.stream() + .allMatch(f -> f instanceof Count && !f.isDistinct() && (((Count) f).isStar() + || f.children().isEmpty() + || (f.children().size() == 1 && f.child(0) instanceof Literal) + || f.child(0) instanceof Slot))) { + return false; + } + Set conjuncts = agg.child().child().getConjuncts(); + if (conjuncts.isEmpty()) { + return false; + } + + Set aggSlots = funcs.stream() + .flatMap(f -> f.getInputSlots().stream()) + .collect(Collectors.toSet()); + return conjuncts.stream().allMatch(expr -> checkSlotInOrExpression(expr, aggSlots)); + }) + .thenApply(ctx -> { + LogicalAggregate>> agg = ctx.root; + LogicalProject> project = agg.child(); + LogicalFilter filter = project.child(); + LogicalOlapScan olapScan = filter.child(); + return pushdownCountOnIndex(agg, project, filter, olapScan, ctx.cascadesContext); + }) ), RuleType.STORAGE_LAYER_AGGREGATE_MINMAX_ON_UNIQUE_WITHOUT_PROJECT.build( logicalAggregate( @@ -331,6 +357,22 @@ private boolean enablePushDownCountOnIndex() { return connectContext != null && connectContext.getSessionVariable().isEnablePushDownCountOnIndex(); } + private boolean checkSlotInOrExpression(Expression expr, Set aggSlots) { + if (expr instanceof Or) { + Set slots = expr.getInputSlots(); + if (!slots.stream().allMatch(aggSlots::contains)) { + return false; + } + } else { + for (Expression child : expr.children()) { + if (!checkSlotInOrExpression(child, aggSlots)) { + return false; + } + } + } + return true; + } + private boolean isDupOrMowKeyTable(LogicalOlapScan logicalScan) { if (logicalScan != null) { KeysType keysType = logicalScan.getTable().getKeysType(); diff --git a/regression-test/data/inverted_index_p0/test_count_on_index_2.out b/regression-test/data/inverted_index_p0/test_count_on_index_2.out index 94d2a83388b38f..de74ba29ffef4d 100644 --- a/regression-test/data/inverted_index_p0/test_count_on_index_2.out +++ b/regression-test/data/inverted_index_p0/test_count_on_index_2.out @@ -101,3 +101,12 @@ -- !sql -- 3 +-- !sql -- +1 + +-- !sql -- +1 + +-- !sql -- +1 + diff --git a/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy b/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy index 851c9120aa2037..8f95e3cb13dcb5 100644 --- a/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy +++ b/regression-test/suites/inverted_index_p0/test_count_on_index_2.groovy @@ -201,6 +201,35 @@ suite("test_count_on_index_2", "p0"){ qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 20) and (b >= 5 and b < 14) and (c >= 16 and c < 25); """ qt_sql """ select count() from ${indexTbName3} where (a >= 10 and a < 20) and (b >= 5 and b < 16) and (c >= 13 and c < 25); """ + sql """ DROP TABLE IF EXISTS tt """ + sql """ + CREATE TABLE `tt` ( + `a` int NULL, + `b` int NULL, + `c` int NULL, + INDEX col_c (`b`) USING INVERTED, + INDEX col_b (`c`) USING INVERTED + ) ENGINE=OLAP + DUPLICATE KEY(`a`) + COMMENT 'OLAP' + DISTRIBUTED BY RANDOM BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + + sql """ insert into tt values (20, 23, 30); """ + sql """ insert into tt values (20, null, 30); """ + qt_sql """ select count(b) from tt where b = 23 or c = 30; """ + qt_sql """ select count(b) from tt where b = 23 and (c = 20 or c = 30); """ + explain { + sql("select count(b) from tt where b = 23 and (c = 20 or c = 30);") + contains "COUNT_ON_INDEX" + } + explain { + sql("select count(b) from tt where b = 23 or b = 30;") + contains "COUNT_ON_INDEX" + } } finally { //try_sql("DROP TABLE IF EXISTS ${testTable}") }