From 80baf7858a3e811aa4e6af5b82210f19c68b24d5 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Thu, 10 Oct 2024 18:06:09 +0800 Subject: [PATCH] fix ut --- ...st.java => PushCountIntoUnionAllTest.java} | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/{PushDownCountIntoUnionAllTest.java => PushCountIntoUnionAllTest.java} (89%) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java similarity index 89% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java index 15a5d1eacf171f..8e46cbb17c1cb4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownCountIntoUnionAllTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushCountIntoUnionAllTest.java @@ -26,7 +26,7 @@ import org.junit.jupiter.api.Test; -public class PushDownCountIntoUnionAllTest extends TestWithFeService implements MemoPatternMatchSupported { +public class PushCountIntoUnionAllTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { createDatabase("test"); @@ -42,7 +42,7 @@ protected void runBeforeAll() throws Exception { @Test void testPushCountStar() { - String sql = "select count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t;"; + String sql = "select id,count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t group by id;"; PlanChecker.from(connectContext) .analyze(sql) .rewrite() @@ -52,7 +52,7 @@ void testPushCountStar() { logicalAggregate().when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Count.class))) ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); - String sql2 = "select count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t;"; + String sql2 = "select id,count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t group by id;"; PlanChecker.from(connectContext) .analyze(sql2) .rewrite() @@ -63,6 +63,19 @@ void testPushCountStar() { ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) ); } + // TODO: not push because after column prune, agg-union transform to agg-project(1)-union, not match rule pattern. + @Test + void testPushCountStarNotPush() { + String sql = "select count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t;"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .nonMatch( + logicalAggregate( + logicalUnion(logicalAggregate(), logicalAggregate()) + ).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class)) + ); + } @Test void testPushCountColumn() {