Skip to content

Commit

Permalink
[feature](nereids) extend infer predicates (apache#40878)
Browse files Browse the repository at this point in the history
This PR refactors the PredicatePropagation module and adds support for predicate deduction, including:
1. Support for predicate deduction of like, not in, and !=;
2. Support for deducing abs(b)=1 from a=b and abs(a)=1;
3. Support for transitive deduction of non-equivalent relations; for example, a>b and b>1 lead to a>1;
4. Deletion of useless predicates.

The following work on predicate inference remains to be done:
1. Support expressions in predicate inference, e.g. abs(t1.c1)>abs(t2.c2) and abs(t1.c1)<1.
2. Add qualifier information to expressions, to determine whether abs(t1.c1) and abs(t2.c2) come from the same table.
  • Loading branch information
feiniaofeiafei authored and eldenmoon committed Oct 10, 2024
1 parent a802c34 commit 05d3aa0
Show file tree
Hide file tree
Showing 40 changed files with 3,501 additions and 406 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ public List<Long> getIndexIds() {
return indexes.getIndexIds();
}

/** Returns this table's {@code indexes} object as-is; no copy is made. */
@Override
public TableIndexes getTableIndexes() {
return indexes;
}
Expand Down
5 changes: 5 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -640,4 +640,9 @@ public long getCachedRowCount() {
public boolean autoAnalyzeEnabled() {
return true;
}

/** Returns a newly constructed {@link TableIndexes} (no-arg constructor) on every call. */
@Override
public TableIndexes getTableIndexes() {
return new TableIndexes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -561,4 +561,6 @@ default boolean isPartitionedTable() {
}

boolean autoAnalyzeEnabled();

/** Returns the {@link TableIndexes} associated with this table. */
TableIndexes getTableIndexes();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.TableAttributes;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.catalog.TableIndexes;
import org.apache.doris.catalog.constraint.Constraint;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Pair;
Expand Down Expand Up @@ -357,4 +358,9 @@ protected Optional<SchemaCacheValue> getSchemaCacheValue() {
ExternalSchemaCache cache = Env.getCurrentEnv().getExtMetaCacheMgr().getSchemaCache(catalog);
return cache.getSchemaValue(dbName, name);
}

/** Returns a newly constructed {@link TableIndexes} (no-arg constructor) on every call. */
@Override
public TableIndexes getTableIndexes() {
return new TableIndexes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// eliminate useless not null or inferred not null
// TODO: wait InferPredicates to infer more not null.
bottomUp(new EliminateNotNull()),
topDown(new ConvertInnerOrCrossJoin()),
topDown(new ProjectOtherJoinConditionForNestedLoopJoin())
topDown(new ConvertInnerOrCrossJoin())
),
topic("Set operation optimization",
// Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator.
Expand Down Expand Up @@ -326,7 +325,12 @@ public class Rewriter extends AbstractBatchJobExecutor {
// after eliminate outer join, we can move some filters to join.otherJoinConjuncts,
// this can help to translate plan to backend
topDown(new PushFilterInsideJoin()),
topDown(new FindHashConditionForJoin())
topDown(new FindHashConditionForJoin()),
// ProjectOtherJoinConditionForNestedLoopJoin pushes down the expressions
// in non-equivalent join conditions and replaces them with slot references.
// INFER_PREDICATES would then be unable to see the child of a Cast,
// which breaks predicate inference through casts. So this rule must run after INFER_PREDICATES.
topDown(new ProjectOtherJoinConditionForNestedLoopJoin())
),
// this rule should invoke after ColumnPruning
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.PredicateInferUtils;

import com.google.common.collect.ImmutableList;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

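/**
 * Infers additional predicates by replacement: for expressions known to be equal (e.g. a = b),
 * a predicate on one of them (e.g. a &gt; 1) is rewritten to refer to the other (b &gt; 1).
 */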
public class InferPredicateByReplace {
/**
 * Collects {@code expr} and the expressions nested below it along a chain of single-child nodes.
 * The collection itself is done by the two-argument overload.
 *
 * @param expr the expression to start from
 * @return a new mutable list holding the collected expressions, {@code expr} first
 */
private static List<Expression> getAllSubExpressions(Expression expr) {
    List<Expression> collected = new ArrayList<>();
    getAllSubExpressions(expr, collected);
    return collected;
}

/**
 * Appends {@code expr} to {@code res}, then keeps descending into {@code child(0)} for as long as
 * the current node has exactly one child, appending every node visited on the way.
 * At the first node whose child count is not one, the descent stops; if that node has exactly one
 * input slot, the slot is appended as well (which may repeat the node when it is that slot itself).
 *
 * @param expr the expression to start from
 * @param res output list that receives the collected expressions
 */
private static void getAllSubExpressions(Expression expr, List<Expression> res) {
    Expression current = expr;
    res.add(current);
    // Walk down the single-child chain, e.g. cast(abs(a)) -> abs(a) -> a.
    while (current.children().size() == 1) {
        current = current.child(0);
        res.add(current);
    }
    Set<Slot> inputSlots = current.getInputSlots();
    if (inputSlots.size() == 1) {
        res.add(inputSlots.iterator().next());
    }
}

/**
 * Fills the context map {@code exprPredicates}: each key is an expression and each value is the
 * ordered set of predicates that were recorded for it.
 * For a supported predicate, the predicate is recorded under every expression returned by
 * {@code getAllSubExpressions} for its left-hand operand. Supported shapes are:
 * an IN predicate for which {@code isLiteralChildren()} holds, a comparison whose right side is a
 * literal, a NOT wrapping one of those two, and a LIKE whose second child is a literal.
 * Every visit method returns {@code null}; the result is delivered only through the context map.
 */
private static class PredicatesCollector extends ExpressionVisitor<Void, Map<Expression, Set<Expression>>> {
    /** Shared instance; the collector has no fields of its own, all state lives in the context map. */
    public static final PredicatesCollector INSTANCE = new PredicatesCollector();

    /** Generic expressions contribute nothing to the map. */
    @Override
    public Void visit(Expression expr, Map<Expression, Set<Expression>> context) {
        return null;
    }

    /** OR predicates are never recorded, and their children are not visited. */
    @Override
    public Void visitOr(Or expr, Map<Expression, Set<Expression>> context) {
        return null;
    }

    @Override
    public Void visitInPredicate(InPredicate inPredicate, Map<Expression, Set<Expression>> context) {
        if (validInPredicate(inPredicate)) {
            registerPredicate(inPredicate.getCompareExpr(), inPredicate, context);
        }
        return null;
    }

    @Override
    public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate,
            Map<Expression, Set<Expression>> context) {
        if (validComparisonPredicate(comparisonPredicate)) {
            // It is believed that 1<a has been rewritten as a>1, so only child(0) is inspected.
            registerPredicate(comparisonPredicate.child(0), comparisonPredicate, context);
        }
        return null;
    }

    @Override
    public Void visitNot(Not not, Map<Expression, Set<Expression>> context) {
        Expression negated = not.child(0);
        boolean supported = (negated instanceof InPredicate && validInPredicate((InPredicate) negated))
                || (negated instanceof ComparisonPredicate
                        && validComparisonPredicate((ComparisonPredicate) negated));
        if (supported) {
            // The whole NOT expression is recorded, keyed by the operand of the negated predicate.
            registerPredicate(negated.child(0), not, context);
        }
        return null;
    }

    @Override
    public Void visitLike(Like like, Map<Expression, Set<Expression>> context) {
        if (like.child(1) instanceof Literal) {
            registerPredicate(like.child(0), like, context);
        }
        return null;
    }

    /** Records {@code predicate} under every expression collected from {@code operand}. */
    private void registerPredicate(Expression operand, Expression predicate,
            Map<Expression, Set<Expression>> context) {
        for (Expression key : getAllSubExpressions(operand)) {
            context.computeIfAbsent(key, k -> new LinkedHashSet<>()).add(predicate);
        }
    }

    /** A comparison is usable only when its right side is a literal. */
    private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) {
        return comparisonPredicate.right() instanceof Literal;
    }

    /** An IN predicate is usable only when {@code isLiteralChildren()} holds for it. */
    private boolean validInPredicate(InPredicate inPredicate) {
        return inPredicate.isLiteralChildren();
    }
}

/**
 * Derives the predicates that hold for {@code replaceToThis} by rewriting the predicates recorded
 * for each member of its equivalence set.
 * Example: with replaceToThis = b, equalSet derived from a=b and
 * exprPredicates = {a: [a&lt;10, a&gt;1], b: [b in (1, 2)]}, the predicates b&lt;10 and b&gt;1 are deduced.
 *
 * @param replaceToThis the expression that the rewritten predicates should refer to
 * @param equalSet the expressions known to be equal to {@code replaceToThis}
 * @param exprPredicates expression -> predicates recorded for that expression
 * @return the rewritten predicates, each marked via {@code withInferred(true)}, in encounter order
 */
private static <T extends Expression> Set<Expression> getEqualSetAndDoReplace(T replaceToThis, Set<T> equalSet,
        Map<? extends Expression, Set<Expression>> exprPredicates) {
    ExpressionAnalyzer analyzer = new ReplaceAnalyzer(null, new Scope(ImmutableList.of()), null, false, false);
    Set<Expression> inferred = new LinkedHashSet<>();
    for (T member : equalSet) {
        if (!exprPredicates.containsKey(member)) {
            // nothing was recorded for this member, so there is nothing to rewrite
            continue;
        }
        Map<Expression, Expression> substitution = new HashMap<>();
        substitution.put(member, replaceToThis);
        for (Expression predicate : exprPredicates.get(member)) {
            Expression rewritten = ExpressionUtils.replace(predicate, substitution);
            try {
                // re-analyze the rewritten predicate so that type coercion and cast checks are applied
                inferred.add(analyzer.analyze(rewritten).withInferred(true));
            } catch (Exception ignored) {
                // has cast error, just not infer and do nothing
            }
        }
    }
    return inferred;
}

/**
 * Builds the equivalence sets implied by the {@code EqualTo} predicates in {@code inputs}.
 * Each equality is handed to {@code PredicateInferUtils.getPairFromCast}, so that for an input such as
 * cast(d_tinyint as int) = d_int the cast is removed and d_tinyint = d_int is extracted.
 * A pair is kept only when both sides satisfy {@code isSlotOrLiteral} and neither is a
 * {@code NullLiteral}. Equalities in which neither side has any input slot are skipped.
 *
 * @param inputs the input predicates
 * @return the equivalence sets built from the accepted pairs
 */
private static ImmutableEqualSet<Expression> findEqual(Set<Expression> inputs) {
    ImmutableEqualSet.Builder<Expression> equalSetBuilder = new ImmutableEqualSet.Builder<>();
    for (Expression candidate : inputs) {
        if (candidate instanceof EqualTo) {
            EqualTo equality = (EqualTo) candidate;
            Set<Slot> leftSlots = equality.left().getInputSlots();
            Set<Slot> rightSlots = equality.right().getInputSlots();
            boolean referencesSlot = !leftSlots.isEmpty() || !rightSlots.isEmpty();
            if (referencesSlot) {
                PredicateInferUtils.getPairFromCast((ComparisonPredicate) candidate)
                        .filter(pair -> PredicateInferUtils.isSlotOrLiteral(pair.first)
                                && PredicateInferUtils.isSlotOrLiteral(pair.second)
                                && !(pair.first instanceof NullLiteral)
                                && !(pair.second instanceof NullLiteral))
                        .ifPresent(pair -> equalSetBuilder.addEqualPair(pair.first, pair.second));
            }
        }
    }
    return equalSetBuilder.build();
}

/**
 * Entry point of predicate inference by replacement.
 * Only deterministic inputs that reference exactly one slot are used as sources for replacement.
 *
 * @param inputs the input predicates for derivation
 * @return a new ordered set containing all of {@code inputs} followed by the derived predicates
 */
public static Set<Expression> infer(Set<Expression> inputs) {
    ImmutableEqualSet<Expression> equalSet = findEqual(inputs);
    Set<Expression> equalMembers = equalSet.getAllItemSet();
    Set<Expression> result = new LinkedHashSet<>(inputs);
    if (equalMembers.isEmpty()) {
        // no usable equality, so nothing can be replaced
        return result;
    }
    Map<Expression, Set<Expression>> exprPredicates = new HashMap<>();
    for (Expression input : inputs) {
        boolean deterministic = !input.anyMatch(expr -> !((ExpressionTrait) expr).isDeterministic());
        if (deterministic && input.getInputSlots().size() == 1) {
            input.accept(PredicatesCollector.INSTANCE, exprPredicates);
        }
    }
    if (exprPredicates.isEmpty()) {
        return result;
    }
    for (Expression target : equalMembers) {
        // a literal is never used as the replacement target
        if (!(target instanceof Literal)) {
            result.addAll(getEqualSetAndDoReplace(target, equalSet.calEqualSet(target), exprPredicates));
        }
    }
    return result;
}

/**
 * Analyzer applied to a predicate after replacement: it performs type conversion and type checking
 * on the rewritten expression.
 * On top of the inherited cast handling, a cast to a decimal type (V3 or V2) is rejected with an
 * {@link AnalysisException} when the target type has fewer integer digits or a smaller scale than
 * the decimal type derived from the child, i.e. when the cast could fail during execution.
 */
private static class ReplaceAnalyzer extends ExpressionAnalyzer {
    private ReplaceAnalyzer(Plan currentPlan, Scope scope,
            @Nullable CascadesContext cascadesContext,
            boolean enableExactMatch, boolean bindSlotInOuterScope) {
        super(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope);
    }

    @Override
    public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
        Cast analyzed = (Cast) super.visitCast(cast, context);
        if (analyzed.getDataType().isDecimalV3Type()) {
            DecimalV3Type target = (DecimalV3Type) analyzed.getDataType();
            DecimalV3Type origin = DecimalV3Type.forType(analyzed.child().getDataType());
            if (decimalCastLosesDigits(origin.getPrecision(), origin.getScale(),
                    target.getPrecision(), target.getScale())) {
                throw lossyDecimalCast(analyzed, target);
            }
        } else if (analyzed.getDataType().isDecimalV2Type()) {
            DecimalV2Type target = (DecimalV2Type) analyzed.getDataType();
            DecimalV2Type origin = DecimalV2Type.forType(analyzed.child().getDataType());
            if (decimalCastLosesDigits(origin.getPrecision(), origin.getScale(),
                    target.getPrecision(), target.getScale())) {
                throw lossyDecimalCast(analyzed, target);
            }
        }
        return analyzed;
    }

    /** True when the target has fewer integer digits (precision - scale) or a smaller scale than the origin. */
    private static boolean decimalCastLosesDigits(int originPrecision, int originScale,
            int targetPrecision, int targetScale) {
        boolean integerPartShrinks = (originPrecision - originScale) > (targetPrecision - targetScale);
        return integerPartShrinks || originScale > targetScale;
    }

    /** Builds the exception reported for a decimal cast that cannot hold every value of its child. */
    private static AnalysisException lossyDecimalCast(Cast cast, Object targetType) {
        return new AnalysisException("can not cast from origin type " + cast.child().getDataType()
                + " to target type=" + targetType);
    }
}
}
Loading

0 comments on commit 05d3aa0

Please sign in to comment.