Skip to content

Commit

Permalink
merge equal and unequal predicate infer
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Sep 19, 2024
1 parent 0e6b17f commit db0de10
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 236 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.PredicateInferUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -157,19 +157,7 @@ private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expres
Set<Expression> baseExpressions = pullUpPredicates(left);
baseExpressions.addAll(pullUpPredicates(right));
condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
// Set<Expression> newExpressions = new HashSet<>();
// newExpressions.addAll(PredicatePropagation.infer(baseExpressions));
// newExpressions.addAll(NonEqualPredicateInfer.inferUnequalPredicates(baseExpressions));

Set<Expression> inferPredicates = new HashSet<>();
Set<Expression> complexPredicates = new HashSet<>();
Set<Expression> simplePredicates = new HashSet<>();
Set<Expression> tmp = ReplacePredicate.infer(baseExpressions);
tmp.addAll(baseExpressions);
ExpressionUtils.getComplexAndSimplePredicates(tmp, complexPredicates, simplePredicates);
inferPredicates.addAll(complexPredicates);
inferPredicates.addAll(NonEqualPredicateInfer.inferUnequalPredicates(simplePredicates));
return inferPredicates;
return PredicateInferUtils.inferPredicate(baseExpressions);
}

private Set<Expression> pullUpPredicates(Plan plan) {
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PredicateInferUtils;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
Expand Down Expand Up @@ -226,23 +227,10 @@ private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predica
if (predicates.isEmpty()) {
return ImmutableSet.of();
}
Set<Expression> inferPredicates = new HashSet<>();
Set<Expression> complexPredicates = new HashSet<>();
Set<Expression> simplePredicates = new HashSet<>();
Set<Expression> tmp = ReplacePredicate.infer(predicates);
tmp.addAll(predicates);
ExpressionUtils.getComplexAndSimplePredicates(tmp, complexPredicates, simplePredicates);
inferPredicates.addAll(complexPredicates);
inferPredicates.addAll(NonEqualPredicateInfer.inferUnequalPredicates(simplePredicates));
Set<Expression> inferPredicates = PredicateInferUtils.inferPredicate(predicates);
Builder<Expression> newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10);
Set<Slot> outputSet = plan.getOutputSet();

// for (Expression predicate : predicates) {
// if (outputSet.containsAll(predicate.getInputSlots())) {
// newPredicates.add(predicate);
// }
// }

for (Expression inferPredicate : inferPredicates) {
if (outputSet.containsAll(inferPredicate.getInputSlots())) {
newPredicates.add(inferPredicate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.rules.AbstractEqualSet;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
Expand All @@ -46,7 +45,6 @@
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.Comparator;
import java.util.HashMap;
Expand Down Expand Up @@ -188,8 +186,8 @@ private static <T extends Expression> Set<Expression> getEqualSetAndDoReplace(T
EqualPairs is the output parameter and the equivalent pair of predicate derivation input,
which is used to ensure that the derivation
does not generate repeated equivalent conditions, such as a=b and b=a */
private static AbstractEqualSet<Slot> findEqual(Set<Expression> inputs, Set<Pair<Slot, Slot>> equalPairs) {
AbstractEqualSet.Builder<Slot> fromCastEqualSetBuilder = new ImmutableEqualSet.Builder<>();
private static ImmutableEqualSet<Slot> findEqual(Set<Expression> inputs, Set<Pair<Slot, Slot>> equalPairs) {
ImmutableEqualSet.Builder<Slot> fromCastEqualSetBuilder = new ImmutableEqualSet.Builder<>();
for (Expression input : inputs) {
if (!(input instanceof EqualTo)) {
continue;
Expand Down Expand Up @@ -219,10 +217,10 @@ private static AbstractEqualSet<Slot> findEqual(Set<Expression> inputs, Set<Pair
* The return value is the derived predicates*/
public static Set<Expression> infer(Set<Expression> inputs) {
Set<Pair<Slot, Slot>> equalPairs = new HashSet<>();
AbstractEqualSet<Slot> hasCastEqualSet = findEqual(inputs, equalPairs);
ImmutableEqualSet<Slot> hasCastEqualSet = findEqual(inputs, equalPairs);
Set<Slot> targetExprs = hasCastEqualSet.getAllItemSet();
if (targetExprs.isEmpty()) {
return ImmutableSet.of();
return new HashSet<>();
}
Map<Expression, Set<Expression>> exprPredicates = new HashMap<>();
for (Expression input : inputs) {
Expand Down Expand Up @@ -326,7 +324,7 @@ private static Optional<Expression> validForInfer(Expression expression, InferTy
}

/* This function is used to input a=b b=c to derive a=c, and return a=c.*/
private static Set<Expression> deduceTransitiveEquality(AbstractEqualSet<Slot> equalSet,
private static Set<Expression> deduceTransitiveEquality(ImmutableEqualSet<Slot> equalSet,
Set<Pair<Slot, Slot>> equalPairs) {
List<Set<Slot>> equalSetList = equalSet.calEqualSetList();
Set<Expression> derivedEqualities = new HashSet<>();
Expand Down Expand Up @@ -359,5 +357,4 @@ private static boolean isSingleTableExpression(Expression expr) {
}
return qualifiers.size() == 1;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,11 @@
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -955,37 +950,4 @@ public static boolean unionConstExprsSatisfyConjuncts(LogicalUnion union, Set<Ex
}
return true;
}

public static boolean isSlotOrLiteral(Expression expr) {
return expr instanceof SlotReference || expr instanceof Literal;
}

public static void getComplexAndSimplePredicates(Set<Expression> inputs, Set<Expression> complex, Set<Expression> simple) {
for (Expression input : inputs) {
if (input instanceof ComparisonPredicate && !(input instanceof NullSafeEqual)) {
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) input;
if (comparisonPredicate.left().equals(comparisonPredicate.right())) {
complex.add(input);
}
Set<Slot> leftSlots = comparisonPredicate.left().getInputSlots();
Set<Slot> rightSlots = comparisonPredicate.right().getInputSlots();
if (leftSlots.isEmpty() && rightSlots.isEmpty()) {
complex.add(input);
}
if (!isSlotOrLiteral(comparisonPredicate.left()) || !isSlotOrLiteral(comparisonPredicate.right())) {
complex.add(input);
}
if (comparisonPredicate instanceof LessThan || comparisonPredicate instanceof LessThanEqual) {
simple.add(comparisonPredicate.commute());
} else if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof GreaterThanEqual
|| comparisonPredicate instanceof EqualTo) {
simple.add(comparisonPredicate);
} else {
complex.add(input);
}
} else {
complex.add(input);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.doris.nereids.util;

import org.apache.doris.nereids.rules.AbstractEqualSet;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand All @@ -34,7 +32,7 @@
/**
* A class representing an immutable set of elements with equivalence relations.
*/
public class ImmutableEqualSet<T> implements AbstractEqualSet<T> {
public class ImmutableEqualSet<T> {
private final Map<T, T> root;

ImmutableEqualSet(Map<T, T> root) {
Expand All @@ -48,7 +46,7 @@ public static <T> ImmutableEqualSet<T> empty() {
/**
* Builder for ImmutableEqualSet.
*/
public static class Builder<T> extends AbstractEqualSet.Builder<T> {
public static class Builder<T> {
private Map<T, T> parent;

Builder(Map<T, T> parent) {
Expand Down Expand Up @@ -140,7 +138,6 @@ private T findRoot(T a) {
return findRoot(parent.get(a));
}

@Override
public ImmutableEqualSet<T> build() {
ImmutableMap.Builder<T, T> foldMapBuilder = new ImmutableMap.Builder<>();
for (T k : parent.keySet()) {
Expand All @@ -153,7 +150,6 @@ public ImmutableEqualSet<T> build() {
/**
* Calculate equal set for a except self
*/
@Override
public Set<T> calEqualSet(T a) {
T ra = root.get(a);
return root.keySet().stream()
Expand All @@ -168,7 +164,6 @@ public boolean isEmpty() {
/**
* Calculate all equal set
*/
@Override
public List<Set<T>> calEqualSetList() {
return root.values()
.stream()
Expand All @@ -181,7 +176,6 @@ public List<Set<T>> calEqualSetList() {
}).collect(ImmutableList.toImmutableList());
}

@Override
public Set<T> getAllItemSet() {
return ImmutableSet.copyOf(root.keySet());
}
Expand All @@ -193,7 +187,6 @@ public boolean isEqual(T l, T r) {
return root.get(l) == root.get(r);
}

@Override
public Set<Map.Entry<T, T>> getAllPairs() {
return root.entrySet();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.util;

import org.apache.doris.nereids.rules.rewrite.NonEqualPredicateInfer;
import org.apache.doris.nereids.rules.rewrite.ReplacePredicate;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;

import java.util.HashSet;
import java.util.Set;

/** PredicateInferUtils */
public class PredicateInferUtils {
public static boolean isSlotOrLiteral(Expression expr) {
return expr instanceof SlotReference || expr instanceof Literal;
}

/**The inputs predicate is divided into two parts. One is the predicate directly reserved, which does not enter
* the non equivalent derivation, and the other is the predicates entering the non equivalent derivation*/
public static void getComplexAndSimplePredicates(Set<Expression> inputs, Set<Expression> complex,
Set<ComparisonPredicate> simple) {
for (Expression input : inputs) {
if (input instanceof ComparisonPredicate && !(input instanceof NullSafeEqual)) {
ComparisonPredicate comparisonPredicate = (ComparisonPredicate) input;
if (comparisonPredicate.left().equals(comparisonPredicate.right())) {
complex.add(input);
}
Set<Slot> leftSlots = comparisonPredicate.left().getInputSlots();
Set<Slot> rightSlots = comparisonPredicate.right().getInputSlots();
if (leftSlots.isEmpty() && rightSlots.isEmpty()) {
complex.add(input);
}
if (!isSlotOrLiteral(comparisonPredicate.left()) || !isSlotOrLiteral(comparisonPredicate.right())) {
complex.add(input);
}
if (comparisonPredicate instanceof GreaterThan || comparisonPredicate instanceof GreaterThanEqual
|| comparisonPredicate instanceof EqualTo || comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof LessThanEqual) {
simple.add(comparisonPredicate);
} else {
complex.add(input);
}
} else {
complex.add(input);
}
}
}

/**The predicate derivation is based on the input predicate predicates, which is divided into two parts.
* The equivalent relation used in ReplacePredicate and calculated by union-find derive like, in, not
* and ComparisonPredicate;
* The NonEqualPredicateInfer class deduces predicates based on non-equal relations, and deletes
* the useless ComparisonPredicates derived from ReplacePredicate*/
public static Set<Expression> inferPredicate(Set<Expression> predicates) {
Set<Expression> inferPredicates = new HashSet<>();
Set<Expression> complexPredicates = new HashSet<>();
Set<ComparisonPredicate> simplePredicates = new HashSet<>();
Set<Expression> inferAndOriginPredicates = ReplacePredicate.infer(predicates);
inferAndOriginPredicates.addAll(predicates);
PredicateInferUtils.getComplexAndSimplePredicates(inferAndOriginPredicates, complexPredicates,
simplePredicates);
inferPredicates.addAll(complexPredicates);
inferPredicates.addAll(NonEqualPredicateInfer.inferUnequalPredicates(simplePredicates));
return inferPredicates;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public void testInferWithTransitiveEqualityWithCast() {

EqualTo expected = new EqualTo(a, b);
Assertions.assertTrue(result.contains(expected) || result.contains(expected.commute()),
"Expected to find a = b in the result.");
"Expected to find a = c in the result.");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ PhysicalResultSink
PhysicalResultSink
--NestedLoopJoin[CROSS_JOIN]
----PhysicalOlapScan[test_like2]
----filter((t1.a = t1.c) and (t1.a > 1) and (t1.c > 1))
----filter((t1.a = t1.c) and (t1.a > 1))
------PhysicalOlapScan[test_like1]

-- !infer_long_chain_diff_table --
Expand Down

0 comments on commit db0de10

Please sign in to comment.