Skip to content

Commit

Permalink
[Feat](nereids) add max/min filter push down rewrite rule
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Aug 21, 2024
1 parent 77ca637 commit 3b81ab7
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
public class MaxMinFilterPushDown extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(logicalAggregate())
return logicalFilter(logicalAggregate().whenNot(agg -> agg.getGroupByExpressions().isEmpty()))
.then(this::pushDownMaxMinFilter)
.toRule(RuleType.MAX_MIN_FILTER_PUSH_DOWN);
}
Expand Down Expand Up @@ -110,7 +110,7 @@ private Plan pushDownMaxMinFilter(LogicalFilter<LogicalAggregate<Plan>> filter)
newPredicate = new LessThanEqual(func.child(0), originConjunct.child(1));
}
}
Preconditions.checkState(newPredicate != null);
Preconditions.checkState(newPredicate != null, "newPredicate is null");
LogicalFilter<Plan> newPushDownFilter = new LogicalFilter<>(ImmutableSet.of(newPredicate), aggChild);
LogicalAggregate<Plan> newAgg = agg.withChildren(ImmutableList.of(newPushDownFilter));
return PlanUtils.filterOrSelf(newUpperConjuncts, newAgg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,37 +51,64 @@ public void testMinRewrite() {
}

@Test
public void testMaxNotRewrite0() {
public void testNotRewriteBecauseFuncIsMoreThanOne1() {
String sql = "select id, min(score), max(name) from max_t group by id having min(score)<10 and max(name)>'abc'";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}
@Test
public void testNotRewriteBecauseFuncIsMoreThanOne2() {
String sql = "select id, min(score), min(name) from max_t group by id having min(score)<10 and min(name)<'abc'";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testMaxNotRewriteBecauseLessThan() {
String sql = "select id, max(score) from max_t group by id having max(score)<10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testMinNotRewrite1() {
public void testMinNotRewriteBecauseGreaterThan() {
String sql = "select id, min(score) from max_t group by id having min(score)>10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testMinNotRewrite2() {
String sql = "select id, min(score), max(score) from max_t group by id having min(score)>10";
public void testMinNotRewriteBecauseHasMaxFunc() {
String sql = "select id, min(score), max(score) from max_t group by id having min(score)<10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testMinNotRewrite3() {
String sql = "select id, min(score), count(score) from max_t group by id having min(score)>10";
public void testMinNotRewriteBecauseHasCountFunc() {
String sql = "select id, min(score), count(score) from max_t group by id having min(score)<10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testMinNotRewrite4() {
public void testNotRewriteBecauseConjunctLeftNotSlot() {
String sql = "select id, max(score) from max_t group by id having abs(max(score))>10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}

@Test
public void testRewriteAggFuncHasExpr() {
String sql = "select id, max(score+1) from max_t group by id having max(score+1)>10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.matches(logicalFilter(logicalOlapScan()).when(filter -> filter.getConjuncts().size() == 1));
}

@Test
public void testNotRewriteScalarAgg() {
String sql = "select max(score+1) from max_t having max(score+1)>10";
PlanChecker.from(connectContext).analyze(sql).rewrite()
.nonMatch(logicalFilter(logicalOlapScan()));
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !scalar_agg_empty_table --
PhysicalResultSink
--filter((min(value1) < 20))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalEmptyRelation

-- !min --
PhysicalResultSink
--hashAgg[GLOBAL]
Expand Down Expand Up @@ -78,38 +85,40 @@ PhysicalResultSink

-- !min_scalar_agg --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------filter((max_min_filter_push_down1.value1 < 40))
--------PhysicalOlapScan[max_min_filter_push_down1]
--filter((min(value1) < 40))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[max_min_filter_push_down1]

-- !max_scalar_agg --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------filter((max_min_filter_push_down1.value1 > 40))
--------PhysicalOlapScan[max_min_filter_push_down1]
--filter((max(value1) > 40))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[max_min_filter_push_down1]

-- !max_scalar_agg --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------filter((max_min_filter_push_down1.value1 > 40))
--------PhysicalOlapScan[max_min_filter_push_down1]
--filter((max(value1) > 40))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[max_min_filter_push_down1]

-- !min_equal_scalar_agg --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------filter((max_min_filter_push_down1.value1 <= 20))
--------PhysicalOlapScan[max_min_filter_push_down1]
--filter((min(value1) <= 20))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[max_min_filter_push_down1]

-- !max_equal_scalar_agg --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------filter((max_min_filter_push_down1.value1 >= 40))
--------PhysicalOlapScan[max_min_filter_push_down1]
--filter((max(value1) >= 40))
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[max_min_filter_push_down1]

-- !scalar_agg_empty_table_res --

-- !min_res --
1 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@ suite("max_min_filter_push_down") {
INSERT INTO max_min_filter_push_down1 (id, value1, value2) VALUES
(1, 10, 'A'),(1, 11, 'A'),(2, 20, 'B'),(2, 73, 'B'),(2, 19, 'B'),(3, 30, 'C'),(3, 61, 'C'),(4, 40, 'D'),(4, 43, 'D'),(4, 45, 'D');
"""
sql "drop table if exists max_min_filter_push_down_empty"
sql "create table max_min_filter_push_down_empty like max_min_filter_push_down1"

qt_scalar_agg_empty_table """
explain shape plan
select min(value1) from max_min_filter_push_down_empty having min(value1) <40 and min(value1) <20;
"""
qt_min """
explain shape plan
select id,min(value1) from max_min_filter_push_down1 group by id having min(value1) <40 and min(value1) <20;
Expand Down Expand Up @@ -105,7 +111,9 @@ suite("max_min_filter_push_down") {
select max(value1) from max_min_filter_push_down1 having max(value1) >=40;
"""


qt_scalar_agg_empty_table_res """
select min(value1) from max_min_filter_push_down_empty having min(value1) <40 and min(value1) <20;
"""
qt_min_res """
select id,min(value1) from max_min_filter_push_down1 group by id having min(value1) <40 and min(value1) <20 order by 1,2;
"""
Expand Down

0 comments on commit 3b81ab7

Please sign in to comment.