Skip to content

Commit

Permalink
add logic for agg-project(1)-union
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Oct 10, 2024
1 parent 80baf78 commit 3292303
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;
Expand Down Expand Up @@ -61,13 +63,19 @@
* +--LogicalAggregate (groupByExpr=[a#7], outputExpr=[a#7, count(a#7) AS `count(a)`#18]]
* +--child3
*/
public class PushCountIntoUnionAll extends OneRewriteRuleFactory {
public class PushCountIntoUnionAll implements RewriteRuleFactory {
@Override
public Rule build() {
return logicalAggregate(logicalUnion().when(this::checkUnion))
public List<Rule> buildRules() {
return ImmutableList.of(logicalAggregate(logicalUnion().when(this::checkUnion))
.when(this::checkAgg)
.then(this::doPush)
.toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL);
.toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL),
logicalAggregate(logicalProject(logicalUnion().when(this::checkUnion)))
.when(this::checkAgg)
.when(this::checkProjectUseless)
.then(this::removeProjectAndPush)
.toRule(RuleType.PUSH_COUNT_INTO_UNION_ALL)
);
}

private Plan doPush(LogicalAggregate<LogicalUnion> agg) {
Expand Down Expand Up @@ -161,6 +169,31 @@ private boolean checkAgg(LogicalAggregate aggregate) {
return !hasUnsuportedAggFunc(aggregate);
}

private boolean checkProjectUseless(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
LogicalProject<LogicalUnion> project = agg.child();
if (project.getProjects().size() != 1) {
return false;
}
if (!(project.getProjects().get(0) instanceof Alias)) {
return false;
}
Alias alias = (Alias) project.getProjects().get(0);
if (!alias.child(0).equals(new TinyIntLiteral((byte) 1))) {
return false;
}
List<NamedExpression> aggOutputs = agg.getOutputExpressions();
Slot slot = project.getOutput().get(0);
if (ExpressionUtils.anyMatch(aggOutputs, expr -> expr.equals(slot))) {
return false;
}
return true;
}

private Plan removeProjectAndPush(LogicalAggregate<LogicalProject<LogicalUnion>> agg) {
Plan afterRemove = agg.withChildren(agg.child().child());
return doPush((LogicalAggregate<LogicalUnion>) afterRemove);
}

private boolean hasUnsuportedAggFunc(LogicalAggregate aggregate) {
// only support count, not suport sum,min... and not support count(distinct)
return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,23 @@ void testPushCountStar() {
).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
);
}
// TODO: not push because after column prune, agg-union transform to agg-project(1)-union, not match rule pattern.

@Test
void testPushCountStarNotPush() {
void testPushCountStarNoOtherColumn() {
String sql = "select count(1) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(
.matches(
logicalAggregate(
logicalUnion(logicalAggregate(), logicalAggregate())
).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
);
String sql2 = "select count(*) from (select id,a from t1 union all select id,a from t1 where id>10) t;";
PlanChecker.from(connectContext)
.analyze(sql2)
.rewrite()
.matches(
logicalAggregate(
logicalUnion(logicalAggregate(), logicalAggregate())
).when(agg -> ExpressionUtils.containsType(agg.getOutputExpressions(), Sum0.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,6 @@ PhysicalResultSink
----------------filter((mal_test_push_count.a = 1))
------------------PhysicalOlapScan[mal_test_push_count]

-- !test_count_star --
32

Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ suite("push_count_into_union_all") {
select a,c1 from (
select a,count(*) c1 from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
union all select a,b from mal_test_push_count where a=1 ) t group by a) outer_table order by 1,2;"""

qt_test_count_star """
select count(*) from (select a,b from mal_test_push_count where a>1 union all select a,b from mal_test_push_count where a<100
union all select a,b from mal_test_push_count where a=1) t order by 1,2;"""
}

0 comments on commit 3292303

Please sign in to comment.