diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDown.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDown.java index c137a5ffa3a97bf..a54c3785b35a723 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDown.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDown.java @@ -63,7 +63,7 @@ public class MaxMinFilterPushDown extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalFilter(logicalAggregate()) + return logicalFilter(logicalAggregate().whenNot(agg -> agg.getGroupByExpressions().isEmpty())) .then(this::pushDownMaxMinFilter) .toRule(RuleType.MAX_MIN_FILTER_PUSH_DOWN); } @@ -110,7 +110,7 @@ private Plan pushDownMaxMinFilter(LogicalFilter> filter) newPredicate = new LessThanEqual(func.child(0), originConjunct.child(1)); } } - Preconditions.checkState(newPredicate != null); + Preconditions.checkState(newPredicate != null, "newPredicate is null"); LogicalFilter newPushDownFilter = new LogicalFilter<>(ImmutableSet.of(newPredicate), aggChild); LogicalAggregate newAgg = agg.withChildren(ImmutableList.of(newPushDownFilter)); return PlanUtils.filterOrSelf(newUpperConjuncts, newAgg); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDownTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDownTest.java index 0c8cced405d317a..6025acf0d32bc0c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDownTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/MaxMinFilterPushDownTest.java @@ -51,37 +51,64 @@ public void testMinRewrite() { } @Test - public void testMaxNotRewrite0() { + public void testNotRewriteBecauseFuncIsMoreThanOne1() { + String sql = "select id, min(score), max(name) from max_t group by id having min(score)<10 and max(name)>'abc'"; + PlanChecker.from(connectContext).analyze(sql).rewrite() + .nonMatch(logicalFilter(logicalOlapScan())); + } + @Test + public void testNotRewriteBecauseFuncIsMoreThanOne2() { + String sql = "select id, min(score), min(name) from max_t group by id having min(score)<10 and min(name)<'abc'"; + PlanChecker.from(connectContext).analyze(sql).rewrite() + .nonMatch(logicalFilter(logicalOlapScan())); + } + + @Test + public void testMaxNotRewriteBecauseLessThan() { String sql = "select id, max(score) from max_t group by id having max(score)<10"; PlanChecker.from(connectContext).analyze(sql).rewrite() .nonMatch(logicalFilter(logicalOlapScan())); } @Test - public void testMinNotRewrite1() { + public void testMinNotRewriteBecauseGreaterThan() { String sql = "select id, min(score) from max_t group by id having min(score)>10"; PlanChecker.from(connectContext).analyze(sql).rewrite() .nonMatch(logicalFilter(logicalOlapScan())); } @Test - public void testMinNotRewrite2() { - String sql = "select id, min(score), max(score) from max_t group by id having min(score)>10"; + public void testMinNotRewriteBecauseHasMaxFunc() { + String sql = "select id, min(score), max(score) from max_t group by id having min(score)<10"; PlanChecker.from(connectContext).analyze(sql).rewrite() .nonMatch(logicalFilter(logicalOlapScan())); } @Test - public void testMinNotRewrite3() { - String sql = "select id, min(score), count(score) from max_t group by id having min(score)>10"; + public void testMinNotRewriteBecauseHasCountFunc() { + String sql = "select id, min(score), count(score) from max_t group by id having min(score)<10"; PlanChecker.from(connectContext).analyze(sql).rewrite() .nonMatch(logicalFilter(logicalOlapScan())); } @Test - public void testMinNotRewrite4() { + public void testNotRewriteBecauseConjunctLeftNotSlot() { String sql = "select id, max(score) from max_t group by id having abs(max(score))>10"; PlanChecker.from(connectContext).analyze(sql).rewrite() .nonMatch(logicalFilter(logicalOlapScan())); } + + @Test + public void testRewriteAggFuncHasExpr() { + String sql = "select id, max(score+1) from max_t group by id having max(score+1)>10"; + PlanChecker.from(connectContext).analyze(sql).rewrite() + .matches(logicalFilter(logicalOlapScan()).when(filter -> filter.getConjuncts().size() == 1)); + } + + @Test + public void testNotRewriteScalarAgg() { + String sql = "select max(score+1) from max_t having max(score+1)>10"; + PlanChecker.from(connectContext).analyze(sql).rewrite() + .nonMatch(logicalFilter(logicalOlapScan())); + } } diff --git a/regression-test/data/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.out b/regression-test/data/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.out index 4ad1baf33872005..0b650210f0a0e2d 100644 --- a/regression-test/data/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.out +++ b/regression-test/data/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.out @@ -1,4 +1,11 @@ -- This file is automatically generated. You should know what you did if you want to edit this +-- !scalar_agg_empty_table -- +PhysicalResultSink +--filter((min(value1) < 20)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalEmptyRelation + -- !min -- PhysicalResultSink --hashAgg[GLOBAL] @@ -78,38 +85,40 @@ PhysicalResultSink -- !min_scalar_agg -- PhysicalResultSink ---hashAgg[GLOBAL] -----hashAgg[LOCAL] -------filter((max_min_filter_push_down1.value1 < 40)) ---------PhysicalOlapScan[max_min_filter_push_down1] +--filter((min(value1) < 40)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalStorageLayerAggregate[max_min_filter_push_down1] -- !max_scalar_agg -- PhysicalResultSink ---hashAgg[GLOBAL] -----hashAgg[LOCAL] -------filter((max_min_filter_push_down1.value1 > 40)) ---------PhysicalOlapScan[max_min_filter_push_down1] +--filter((max(value1) > 40)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalStorageLayerAggregate[max_min_filter_push_down1] -- !max_scalar_agg -- PhysicalResultSink ---hashAgg[GLOBAL] -----hashAgg[LOCAL] -------filter((max_min_filter_push_down1.value1 > 40)) ---------PhysicalOlapScan[max_min_filter_push_down1] +--filter((max(value1) > 40)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalStorageLayerAggregate[max_min_filter_push_down1] -- !min_equal_scalar_agg -- PhysicalResultSink ---hashAgg[GLOBAL] -----hashAgg[LOCAL] -------filter((max_min_filter_push_down1.value1 <= 20)) ---------PhysicalOlapScan[max_min_filter_push_down1] +--filter((min(value1) <= 20)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalStorageLayerAggregate[max_min_filter_push_down1] -- !max_equal_scalar_agg -- PhysicalResultSink ---hashAgg[GLOBAL] -----hashAgg[LOCAL] -------filter((max_min_filter_push_down1.value1 >= 40)) ---------PhysicalOlapScan[max_min_filter_push_down1] +--filter((max(value1) >= 40)) +----hashAgg[GLOBAL] +------hashAgg[LOCAL] +--------PhysicalStorageLayerAggregate[max_min_filter_push_down1] + +-- !scalar_agg_empty_table_res -- -- !min_res -- 1 10 diff --git a/regression-test/suites/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.groovy b/regression-test/suites/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.groovy index b1acd02d0e5c6d9..635fd180aa9b91b 100644 --- a/regression-test/suites/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.groovy +++ b/regression-test/suites/nereids_rules_p0/max_min_filter_push_down/max_min_filter_push_down.groovy @@ -32,7 +32,13 @@ suite("max_min_filter_push_down") { INSERT INTO max_min_filter_push_down1 (id, value1, value2) VALUES (1, 10, 'A'),(1, 11, 'A'),(2, 20, 'B'),(2, 73, 'B'),(2, 19, 'B'),(3, 30, 'C'),(3, 61, 'C'),(4, 40, 'D'),(4, 43, 'D'),(4, 45, 'D'); """ + sql "drop table if exists max_min_filter_push_down_empty" + sql "create table max_min_filter_push_down_empty like max_min_filter_push_down1" + qt_scalar_agg_empty_table """ + explain shape plan + select min(value1) from max_min_filter_push_down_empty having min(value1) <40 and min(value1) <20; + """ qt_min """ explain shape plan select id,min(value1) from max_min_filter_push_down1 group by id having min(value1) <40 and min(value1) <20; @@ -105,7 +111,9 @@ suite("max_min_filter_push_down") { select max(value1) from max_min_filter_push_down1 having max(value1) >=40; """ - + qt_scalar_agg_empty_table_res """ + select min(value1) from max_min_filter_push_down_empty having min(value1) <40 and min(value1) <20; + """ qt_min_res """ select id,min(value1) from max_min_filter_push_down1 group by id having min(value1) <40 and min(value1) <20 order by 1,2; """