Skip to content

Commit

Permalink
[Feature](materialized-view) support ignore not slot is null when cou…
Browse files Browse the repository at this point in the history
…nt(slot) not has key in mv (#32912)

Support ignoring the `NOT (slot IS NULL)` predicate when the slot used in count(slot) is not a key column in the materialized view.
  • Loading branch information
BiteTheDDDDt authored Apr 1, 2024
1 parent 9be7e30 commit f6de537
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.rewrite.mv;

import org.apache.doris.analysis.CreateMaterializedViewStmt;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.MaterializedIndex;
import org.apache.doris.catalog.MaterializedIndexMeta;
Expand All @@ -33,12 +34,14 @@
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
Expand All @@ -62,6 +65,7 @@
import com.google.common.collect.Lists;
import org.apache.commons.collections.CollectionUtils;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
Expand All @@ -72,6 +76,7 @@
import java.util.TreeSet;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Base class for selecting materialized index rules.
Expand Down Expand Up @@ -109,6 +114,45 @@ protected boolean shouldSelectIndexWithoutAgg(LogicalOlapScan scan) {
}
}

/**
 * Collects the predicates that may be ignored when every aggregate function is a sum.
 *
 * <p>A predicate qualifies when it has the shape {@code Not(IsNull(x))} and the normalized
 * SUM mv-column name built from {@code slotToCaseWhen(x)} equals (ignoring case) the SQL text
 * of the first child of one of the given aggregate expressions.
 *
 * @param aggExpressions aggregate expressions; only the SQL of each one's first child is used
 * @param predicateExpr candidate predicates to inspect
 * @return a new list holding the predicates from {@code predicateExpr} that matched
 */
protected static List<Expression> getPrunedPredicatesWithAllSumAgg(List<Expression> aggExpressions,
        Set<Expression> predicateExpr) {
    // SQL text of each aggregate's argument, looked up case-insensitively below.
    Set<String> sumArgSql = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
    for (Expression agg : aggExpressions) {
        sumArgSql.add(agg.child(0).toSql());
    }

    List<Expression> ignorable = new ArrayList<>();
    for (Expression predicate : predicateExpr) {
        if (!(predicate instanceof Not)) {
            continue;
        }
        Expression negated = predicate.child(0);
        if (!(negated instanceof IsNull)) {
            continue;
        }
        // Build the mv column name that a sum over the case-when form of the operand would have.
        String caseWhenColumn = CreateMaterializedViewStmt.mvColumnBuilder(
                slotToCaseWhen(negated.child(0)).toSql());
        String mvSumColumn = normalizeName(
                CreateMaterializedViewStmt.mvColumnBuilder(AggregateType.SUM, caseWhenColumn));
        if (sumArgSql.contains(mvSumColumn)) {
            ignorable.add(predicate);
        }
    }
    return ignorable;
}

/**
 * Returns the predicates that can be pruned; intended for the case without a group-by column.
 *
 * <p>Pruning is only attempted when every element of {@code aggExpressions} is a {@link Sum}
 * (an empty list counts as "all sums"); otherwise an empty list is returned.
 *
 * @param aggExpressions aggregate expressions of the query
 * @param predicateExpr candidate predicates
 * @return a new list of prunable predicates, possibly empty
 */
protected static List<Expression> getPrunedPredicates(List<Expression> aggExpressions,
        Set<Expression> predicateExpr) {
    boolean onlySums = aggExpressions.stream().allMatch(Sum.class::isInstance);
    if (!onlySums) {
        return new ArrayList<>();
    }
    return new ArrayList<>(getPrunedPredicatesWithAllSumAgg(aggExpressions, predicateExpr));
}

protected static boolean containAllRequiredColumns(MaterializedIndex index, LogicalOlapScan scan,
Set<Slot> requiredScanOutput, Set<? extends Expression> requiredExpr, Set<Expression> predicateExpr) {
OlapTable table = scan.getTable();
Expand All @@ -121,12 +165,14 @@ protected static boolean containAllRequiredColumns(MaterializedIndex index, Logi
.map(e -> {
e.setDisableTableName(true);
return e;
})
.map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet());
Set<String> commonConjuncts = indexConjuncts.stream().filter(predicateExprSql::contains)
.collect(Collectors.toSet());
if (commonConjuncts.size() != indexConjuncts.size()) {
return false;
}).map(e -> new NereidsParser().parseExpression(e.toSql()).toSql()).collect(Collectors.toSet());

for (String indexConjunct : indexConjuncts) {
if (predicateExprSql.contains(indexConjunct)) {
predicateExprSql.remove(indexConjunct);
} else {
return false;
}
}

Set<String> requiredMvColumnNames = requiredScanOutput.stream()
Expand All @@ -138,10 +184,24 @@ protected static boolean containAllRequiredColumns(MaterializedIndex index, Logi
.collect(Collectors.toCollection(() -> new TreeSet<String>(String.CASE_INSENSITIVE_ORDER)));
mvColNames.addAll(indexConjuncts);

return mvColNames.containsAll(requiredMvColumnNames)
&& (indexConjuncts.isEmpty() || commonConjuncts.size() == predicateExprSql.size())
|| requiredExpr.stream().filter(e -> !containsAllColumn(e, mvColNames)).collect(Collectors.toSet())
.isEmpty();
if (mvColNames.containsAll(requiredMvColumnNames) && predicateExprSql.isEmpty()) {
return true;
}

Set<Expression> remained = requiredExpr.stream().filter(e -> !containsAllColumn(e, mvColNames))
.collect(Collectors.toSet());
if (remained.isEmpty()) {
return true;
}

if (!scan.getGroupExpression().isPresent()) {
Set<Expression> prunedExpr = getPrunedPredicates(
requiredExpr.stream().filter(e -> e instanceof AggregateFunction).collect(Collectors.toList()),
predicateExpr).stream().collect(Collectors.toSet());
remained = remained.stream().filter(e -> !prunedExpr.contains(e)).collect(Collectors.toSet());
}

return remained.isEmpty();
}

public static String parseMvColumnToSql(String mvName) {
Expand Down Expand Up @@ -433,6 +493,21 @@ protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan) {
.collect(Collectors.toSet()));
}

/**
 * Variant of {@code generateBaseScanExprToMvExpr(mvPlan)} to be called only when the plan has
 * both an aggregate and a filter.
 *
 * <p>When {@code mvPlan} has no group expression, the predicates reported by
 * {@code getPrunedPredicates} for the aggregate functions in {@code requiredExpr} are added to
 * the context's {@code trueExprs}. Otherwise the single-argument result is returned unchanged.
 *
 * @param mvPlan scan over the selected index
 * @param requiredExpr required expressions; only {@link AggregateFunction} instances are used
 * @param predicateExpr filter conjuncts to test for pruning
 * @return the slot context, possibly with extra entries in {@code trueExprs}
 */
protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan, Set<Expression> requiredExpr,
        Set<Expression> predicateExpr) {
    SlotContext baseContext = generateBaseScanExprToMvExpr(mvPlan);
    if (!mvPlan.getGroupExpression().isPresent()) {
        List<Expression> aggFunctions = requiredExpr.stream()
                .filter(AggregateFunction.class::isInstance)
                .collect(Collectors.toList());
        Set<Expression> ignorable = getPrunedPredicates(aggFunctions, predicateExpr).stream()
                .collect(Collectors.toSet());

        return new SlotContext(baseContext.baseSlotToMvSlot, baseContext.mvNameToMvSlot,
                Stream.concat(ignorable.stream(), baseContext.trueExprs.stream())
                        .collect(Collectors.toSet()));
    }
    return baseContext;
}

/** SlotContext */
protected static class SlotContext {
public static final SlotContext EMPTY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());

return new LogicalProject<>(
generateProjectsAlias(agg.getOutputs(), slotContext),
Expand Down Expand Up @@ -250,7 +252,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());
if (result.indexId == scan.getTable().getBaseIndexId()) {
LogicalOlapScan mvPlanWithoutAgg = SelectMaterializedIndexWithoutAggregate.select(scan,
project::getInputSlots, filter::getConjuncts,
Expand Down Expand Up @@ -311,7 +315,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());

List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
Expand Down Expand Up @@ -390,7 +396,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());

return new LogicalProject<>(
generateProjectsAlias(agg.getOutputs(), slotContext),
Expand Down Expand Up @@ -481,7 +489,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());

List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
Expand Down Expand Up @@ -531,7 +541,9 @@ public List<Rule> buildRules() {
);

LogicalOlapScan mvPlan = createLogicalOlapScan(scan, result);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan);
SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan, requiredExpr.stream()
.map(e -> result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
filter.getConjuncts());

List<NamedExpression> newProjectList = replaceProjectList(project,
result.exprRewriteMap.projectExprMap);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_star --
\N 4 \N d
-4 -4 -4 d
1 1 1 a
2 2 2 b
3 -3 \N c
5 \N \N \N

-- !select_mv --
5

-- !select_mv --
5

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

import org.codehaus.groovy.runtime.IOGroovyMethods

// Regression suite: creates a duplicate-key table with a materialized view over count(k2)
// and checks that count(k2) queries, with and without a `k2 is not null` filter, produce
// explain output containing "(kign)" and return the recorded results.
suite ("mv_ignore_predicate") {

// Start from a clean state so the suite can be re-run.
sql """ DROP TABLE IF EXISTS d_table; """

sql """
create table d_table(
k1 int null,
k2 int null,
k3 bigint null,
k4 varchar(100) null
)
duplicate key (k1,k2,k3)
distributed BY hash(k1) buckets 3
properties("replication_num" = "1");
"""

// Rows loaded before the materialized view exists.
sql "insert into d_table select 1,1,1,'a';"
sql "insert into d_table select 2,2,2,'b';"
sql "insert into d_table select 3,-3,null,'c';"

// Materialized view `kign`: count(k2) grouped by k1 (k2 itself is not a group-by column of the view).
createMV("create materialized view kign as select k1,count(k2) from d_table group by k1;")

// Rows loaded after the materialized view is created, including a partial-column insert
// and a row whose k2 is null.
sql "insert into d_table select -4,-4,-4,'d';"
sql "insert into d_table(k4,k2) values('d',4);"
sql "insert into d_table select 5,null,null,null;"

qt_select_star "select * from d_table order by k1;"

// Plain count(k2): the explain output must contain "(kign)"
// (presumably the marker that the planner picked the materialized view).
explain {
sql("select count(k2) from d_table;")
contains "(kign)"
}
// Five of the six inserted rows have a non-null k2 (1, 2, -3, -4, 4).
qt_select_mv "select count(k2) from d_table;"

// Same aggregate with an added `k2 is not null` filter: the explain output must still contain "(kign)".
explain {
sql("select count(k2) from d_table where k2 is not null;")
contains "(kign)"
}
// The filter only removes rows that count(k2) would skip anyway, so the same five rows are counted.
qt_select_mv "select count(k2) from d_table where k2 is not null;"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.codehaus.groovy.runtime.IOGroovyMethods

suite ("test_dup_mv_repeat") {

sql """ DROP TABLE IF EXISTS d_table; """
sql """ DROP TABLE IF EXISTS db1; """

sql """
CREATE TABLE `db1` (
Expand Down

0 comments on commit f6de537

Please sign in to comment.