Skip to content

Commit

Permalink
[Improvement](Nereids) Support aggregate rewrite by materialized view…
Browse files Browse the repository at this point in the history
… with complex expression (#30440)

materialized view definition is

>            select
>            sum(o_totalprice) as sum_total,
>            max(o_totalprice) as max_total,
>            min(o_totalprice) as min_total,
>           count(*) as count_all,
>            bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as cnt_1,
>            bitmap_union(to_bitmap(case when o_shippriority > 2 and o_orderkey IN (2) then o_custkey else null end)) as cnt_2
>            from lineitem
>            left join orders on l_orderkey = o_orderkey and l_shipdate = o_orderdate;
   

The following query can be rewritten by the materialized view above.
It uses arithmetic calculations over aggregate functions in the select list.

>            select
>            count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) then o_custkey else null end) as cnt_2,
>            (sum(o_totalprice) + min(o_totalprice)) * count(*),
>            min(o_totalprice) + count(distinct case when O_SHIPPRIORITY > 2 and o_orderkey IN (2) then o_custkey else null end)
>            from lineitem
>            left join orders on l_orderkey = o_orderkey and l_shipdate = o_orderdate;
  • Loading branch information
seawinde authored Jan 29, 2024
1 parent 193f62e commit 885d125
Show file tree
Hide file tree
Showing 5 changed files with 402 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
Expand Down Expand Up @@ -65,6 +66,8 @@ public abstract class AbstractMaterializedViewAggregateRule extends AbstractMate

protected static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
protected static final AggregateExpressionRewriter AGGREGATE_EXPRESSION_REWRITER =
new AggregateExpressionRewriter();

static {
// support count distinct roll up
Expand Down Expand Up @@ -156,7 +159,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
= topPlanSplitToGroupAndFunction(queryTopPlanAndAggPair);
Set<? extends Expression> queryTopPlanFunctionSet = queryGroupAndFunctionPair.value();
// try to rewrite, contains both roll up aggregate functions and aggregate group expression
List<NamedExpression> finalAggregateExpressions = new ArrayList<>();
List<NamedExpression> finalOutputExpressions = new ArrayList<>();
List<Expression> finalGroupExpressions = new ArrayList<>();
List<? extends Expression> queryExpressions = queryTopPlan.getExpressions();
// permute the mv expr mapping to query based
Expand All @@ -169,42 +172,40 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
Expression queryFunctionShuttled = ExpressionUtils.shuttleExpressionWithLineage(
topExpression,
queryTopPlan);
// try to roll up
List<Object> queryFunctions =
queryFunctionShuttled.collectFirst(expr -> expr instanceof AggregateFunction);
if (queryFunctions.isEmpty()) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Can not found query function",
String.format("queryFunctionShuttled = %s", queryFunctionShuttled)));
return null;
}
Function rollupAggregateFunction = rollup((AggregateFunction) queryFunctions.get(0),
queryFunctionShuttled, mvExprToMvScanExprQueryBased);
if (rollupAggregateFunction == null) {
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
false, mvExprToMvScanExprQueryBased, queryTopPlan);
// queryFunctionShuttled maybe sum(column) + count(*), so need to use expression rewriter
Expression rollupedExpression = queryFunctionShuttled.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
if (!context.isValid()) {
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("Query function roll up fail",
String.format("queryFunction = %s,\n queryFunctionShuttled = %s,\n"
+ "mvExprToMvScanExprQueryBased = %s",
queryFunctions.get(0), queryFunctionShuttled,
mvExprToMvScanExprQueryBased)));
String.format("queryFunctionShuttled = %s,\n mvExprToMvScanExprQueryBased = %s",
queryFunctionShuttled, mvExprToMvScanExprQueryBased)));
return null;
}
finalAggregateExpressions.add(new Alias(rollupAggregateFunction));
finalOutputExpressions.add(new Alias(rollupedExpression));
} else {
// if group by expression, try to rewrite group by expression
Expression queryGroupShuttledExpr =
ExpressionUtils.shuttleExpressionWithLineage(topExpression, queryTopPlan);
if (!mvExprToMvScanExprQueryBased.containsKey(queryGroupShuttledExpr)) {
AggregateExpressionRewriteContext context = new AggregateExpressionRewriteContext(
true, mvExprToMvScanExprQueryBased, queryTopPlan);
// group by expression maybe group by a + b, so we need expression rewriter
Expression rewrittenGroupByExpression = queryGroupShuttledExpr.accept(AGGREGATE_EXPRESSION_REWRITER,
context);
if (!context.isValid()) {
// group expr can not rewrite by view
materializationContext.recordFailReason(queryStructInfo.getOriginalPlanId(),
Pair.of("View dimensions doesn't not cover the query dimensions",
String.format("mvExprToMvScanExprQueryBased is %s,\n queryGroupShuttledExpr is %s",
mvExprToMvScanExprQueryBased, queryGroupShuttledExpr)));
return null;
}
Expression expression = mvExprToMvScanExprQueryBased.get(queryGroupShuttledExpr);
finalAggregateExpressions.add((NamedExpression) expression);
finalGroupExpressions.add(expression);
NamedExpression groupByExpression = rewrittenGroupByExpression instanceof NamedExpression
? (NamedExpression) rewrittenGroupByExpression : new Alias(rewrittenGroupByExpression);
finalOutputExpressions.add(groupByExpression);
finalGroupExpressions.add(groupByExpression);
}
}
// add project to guarantee group by column ref is slot reference,
Expand All @@ -229,7 +230,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
return (NamedExpression) expr;
})
.collect(Collectors.toList());
finalAggregateExpressions = finalAggregateExpressions.stream()
finalOutputExpressions = finalOutputExpressions.stream()
.map(expr -> {
ExprId exprId = expr.getExprId();
if (projectOutPutExprIdMap.containsKey(exprId)) {
Expand All @@ -238,7 +239,7 @@ protected Plan rewriteQueryByView(MatchMode matchMode,
return expr;
})
.collect(Collectors.toList());
return new LogicalAggregate(finalGroupExpressions, finalAggregateExpressions, mvProject);
return new LogicalAggregate(finalGroupExpressions, finalOutputExpressions, mvProject);
}

private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanAndAggPair,
Expand Down Expand Up @@ -273,7 +274,7 @@ private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanA
* the queryAggregateFunction is max(a), queryAggregateFunctionShuttled is max(a) + 1
* mvExprToMvScanExprQueryBased is { max(a) : MTMVScan(output#0) }
*/
private Function rollup(AggregateFunction queryAggregateFunction,
private static Function rollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryAggregateFunction instanceof CouldRollUp)) {
Expand Down Expand Up @@ -310,7 +311,7 @@ private Function rollup(AggregateFunction queryAggregateFunction,
// Check the aggregate function can roll up or not, return true if could roll up
// if view aggregate function is distinct or is in the un supported rollup functions, it doesn't support
// roll up.
private boolean canRollup(Expression rollupExpression) {
private static boolean canRollup(Expression rollupExpression) {
if (rollupExpression == null) {
return false;
}
Expand Down Expand Up @@ -402,7 +403,7 @@ protected boolean checkPattern(StructInfo structInfo) {
* This will check the count(distinct a) in query is equivalent to bitmap_union(to_bitmap(a)) in mv,
* and then check their arguments is equivalent.
*/
private boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
private static boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
Expand Down Expand Up @@ -438,9 +439,109 @@ private boolean isAggregateFunctionEquivalent(Function queryFunction, Function v
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
private List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
private static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}

/**
 * Expression rewriter that handles both group by expressions and aggregate function
 * expressions. A failure is never thrown: it is reported by marking the supplied
 * {@link AggregateExpressionRewriteContext} as invalid. Once the context is invalid,
 * every visit method hands back its input expression untouched.
 */
protected static class AggregateExpressionRewriter
        extends DefaultExpressionRewriter<AggregateExpressionRewriteContext> {

    /**
     * Shuttles the aggregate function through the lineage of the query top plan and tries to
     * roll it up against the mapping held by the context.
     *
     * @return the rolled up function, or the original aggregate function when the context is
     *         already invalid or the roll up yields {@code null} (the context is then invalidated)
     */
    @Override
    public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
            AggregateExpressionRewriteContext rewriteContext) {
        if (rewriteContext.isValid()) {
            Expression shuttledFunction = ExpressionUtils.shuttleExpressionWithLineage(
                    aggregateFunction,
                    rewriteContext.getQueryTopPlan());
            Function rolledUpFunction = rollup(aggregateFunction, shuttledFunction,
                    rewriteContext.getMvExprToMvScanExprQueryBasedMapping());
            if (rolledUpFunction != null) {
                return rolledUpFunction;
            }
            // roll up failed, flag the whole rewrite as failed
            rewriteContext.setValid(false);
        }
        return aggregateFunction;
    }

    /**
     * Replaces the slot by the expression it is mapped to in the context.
     *
     * @return the mapped expression, or the original slot when the context is already invalid
     *         or the mapping has no entry for the slot (the context is then invalidated)
     */
    @Override
    public Expression visitSlot(Slot slot, AggregateExpressionRewriteContext rewriteContext) {
        if (!rewriteContext.isValid()) {
            return slot;
        }
        Map<Expression, Expression> exprMapping = rewriteContext.getMvExprToMvScanExprQueryBasedMapping();
        if (!exprMapping.containsKey(slot)) {
            // a slot without a mapped counterpart can not be rewritten
            rewriteContext.setValid(false);
            return slot;
        }
        return exprMapping.get(slot);
    }

    /**
     * Generic visit: rewrites the children one by one and rebuilds the expression only when
     * at least one child instance was replaced.
     *
     * @return the mapped or rebuilt expression, or the original expression when nothing
     *         changed or the context is (or becomes) invalid
     */
    @Override
    public Expression visit(Expression expr, AggregateExpressionRewriteContext rewriteContext) {
        if (!rewriteContext.isValid()) {
            return expr;
        }
        // for group by expression try to get corresponding expression directly
        if (rewriteContext.isOnlyContainGroupByExpression()) {
            Map<Expression, Expression> exprMapping = rewriteContext.getMvExprToMvScanExprQueryBasedMapping();
            if (exprMapping.containsKey(expr)) {
                return exprMapping.get(expr);
            }
        }
        List<Expression> rewrittenChildren = new ArrayList<>(expr.arity());
        boolean anyChildReplaced = false;
        for (Expression originalChild : expr.children()) {
            Expression rewrittenChild = originalChild.accept(this, rewriteContext);
            if (!rewriteContext.isValid()) {
                // a child failed to rewrite, give up and keep the original expression
                return expr;
            }
            // reference comparison on purpose: detect whether a new instance was produced
            anyChildReplaced |= rewrittenChild != originalChild;
            rewrittenChildren.add(rewrittenChild);
        }
        if (!anyChildReplaced) {
            return expr;
        }
        return expr.withChildren(rewrittenChildren);
    }
}

/**
 * Context handed to {@link AggregateExpressionRewriter}. It carries the read-only inputs of
 * one rewrite (expression mapping, query top plan, group-by-only switch) together with a
 * mutable validity flag that starts as {@code true} and is cleared when the rewrite fails.
 */
protected static class AggregateExpressionRewriteContext {
    private final Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping;
    private final Plan queryTopPlan;
    private final boolean onlyContainGroupByExpression;
    // whether the rewrite is still considered successful
    private boolean valid;

    /**
     * Creates a context whose validity flag is initially {@code true}.
     *
     * @param onlyContainGroupByExpression value later returned by
     *        {@link #isOnlyContainGroupByExpression()}
     * @param mvExprToMvScanExprQueryBasedMapping mapping used to look up replacement expressions
     * @param queryTopPlan top plan of the query
     */
    public AggregateExpressionRewriteContext(boolean onlyContainGroupByExpression,
            Map<Expression, Expression> mvExprToMvScanExprQueryBasedMapping, Plan queryTopPlan) {
        this.valid = true;
        this.queryTopPlan = queryTopPlan;
        this.mvExprToMvScanExprQueryBasedMapping = mvExprToMvScanExprQueryBasedMapping;
        this.onlyContainGroupByExpression = onlyContainGroupByExpression;
    }

    /** Returns the top plan of the query given at construction. */
    public Plan getQueryTopPlan() {
        return queryTopPlan;
    }

    /** Returns the expression mapping given at construction (the same instance, not a copy). */
    public Map<Expression, Expression> getMvExprToMvScanExprQueryBasedMapping() {
        return mvExprToMvScanExprQueryBasedMapping;
    }

    /** Returns the group-by-only switch given at construction. */
    public boolean isOnlyContainGroupByExpression() {
        return onlyContainGroupByExpression;
    }

    /** Returns the current validity flag. */
    public boolean isValid() {
        return valid;
    }

    /** Overwrites the validity flag. */
    public void setValid(boolean valid) {
        this.valid = valid;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,30 @@
2023-12-11 3 43.20 43.20 43.20 1 \N \N 0 1 1
2023-12-12 3 57.40 56.20 1.20 2 \N \N 0 1 1

-- !query25_3_before --
2023-12-08 5 21.00 10.50 9.50 2 \N \N 1 0 1 0
2023-12-09 7 11.50 11.50 11.50 1 \N \N 1 0 1 0
2023-12-10 6 67.00 33.50 12.50 2 \N \N 1 0 1 0
2023-12-11 6 43.20 43.20 43.20 1 \N \N 0 1 1 1
2023-12-12 5 112.40 56.20 1.20 2 \N \N 0 1 1 1

-- !query25_3_after --
2023-12-08 5 21.00 10.50 9.50 2 \N \N 1 0 1 0
2023-12-09 7 11.50 11.50 11.50 1 \N \N 1 0 1 0
2023-12-10 6 67.00 33.50 12.50 2 \N \N 1 0 1 0
2023-12-11 6 43.20 43.20 43.20 1 \N \N 0 1 1 1
2023-12-12 5 112.40 56.20 1.20 2 \N \N 0 1 1 1

-- !query25_4_before --
2 3 2023-12-08 20.00 23.00
2 3 2023-12-12 57.40 60.40
2 4 2023-12-10 46.00 50.00

-- !query25_4_after --
2 3 2023-12-08 20.00 23.00
2 3 2023-12-12 57.40 60.40
2 4 2023-12-10 46.00 50.00

-- !query1_1_before --
1 yy 0 0 11.50 11.50 11.50 1

Expand Down Expand Up @@ -261,6 +285,12 @@
-- !query29_1_after --
0 178.10 1.20 8

-- !query29_2_before --
0 1434.40 1.20

-- !query29_2_after --
0 1434.40 1.20

-- !query30_0_before --
4 4 68 100.0000 36.5000
6 1 0 22.0000 57.2000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@
2 4 2023-12-10 46.00 33.50 12.50 2 0
3 3 2023-12-11 43.20 43.20 43.20 1 0

-- !query15_1_before --
2 20231211 30.50 20.00 10.50 9.50 2 0 2
2 20231214 79.50 46.00 33.50 12.50 2 0 2
2 20231215 113.60 57.40 56.20 1.20 2 0 2
3 20231214 86.40 43.20 43.20 43.20 1 0 1

-- !query15_0_after --
2 20231211 30.50 20.00 10.50 9.50 2 0 2
2 20231214 79.50 46.00 33.50 12.50 2 0 2
2 20231215 113.60 57.40 56.20 1.20 2 0 2
3 20231214 86.40 43.20 43.20 43.20 1 0 1

-- !query16_0_before --
2 3 2023-12-08 20.00 10.50 9.50 2 0
2 3 2023-12-12 57.40 56.20 1.20 2 0
Expand All @@ -99,6 +111,20 @@
3 3 2023-12-11 43.20 43.20 43.20 1 0
4 3 2023-12-09 11.50 11.50 11.50 1 0

-- !query16_1_before --
3 2023-12-08 20.00 10.50 9.50 2 0
3 2023-12-09 11.50 11.50 11.50 1 0
3 2023-12-11 43.20 43.20 43.20 1 0
3 2023-12-12 57.40 56.20 1.20 2 0
4 2023-12-10 46.00 33.50 12.50 2 0

-- !query16_1_after --
3 2023-12-08 20.00 10.50 9.50 2 0
3 2023-12-09 11.50 11.50 11.50 1 0
3 2023-12-11 43.20 43.20 43.20 1 0
3 2023-12-12 57.40 56.20 1.20 2 0
4 2023-12-10 46.00 33.50 12.50 2 0

-- !query17_0_before --
3 3 2023-12-11 43.20 43.20 43.20 1 0

Expand Down Expand Up @@ -177,6 +203,20 @@
2023-12-11 3 3 3 43.20 43.20 43.20 1
2023-12-12 2 3 3 57.40 56.20 1.20 2

-- !query19_3_before --
2 23.00 2023-12-08 20.00 10.50 9.50 29.50 2
2 50.00 2023-12-10 46.00 33.50 12.50 58.50 2
2 60.40 2023-12-12 57.40 56.20 1.20 58.60 2
3 46.20 2023-12-11 43.20 43.20 43.20 86.40 1
4 14.50 2023-12-09 11.50 11.50 11.50 23.00 1

-- !query19_3_after --
2 23.00 2023-12-08 20.00 10.50 9.50 29.50 2
2 50.00 2023-12-10 46.00 33.50 12.50 58.50 2
2 60.40 2023-12-12 57.40 56.20 1.20 58.60 2
3 46.20 2023-12-11 43.20 43.20 43.20 86.40 1
4 14.50 2023-12-09 11.50 11.50 11.50 23.00 1

-- !query20_0_before --
0 0 0 0 0 0 0 0 0 0 0 0

Expand Down
Loading

0 comments on commit 885d125

Please sign in to comment.