From 05d3aa0028d2bc37ab9b49ab1ff78e00c823ed47 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei <53502832+feiniaofeiafei@users.noreply.github.com> Date: Thu, 10 Oct 2024 15:49:16 +0800 Subject: [PATCH] [feature](nereids) extend infer predicates (#40878) This pr refactors the PredicatePropagation module and adds support for predicate deduction, including: 1. Support for predicate deduction of like, not in, !=; 2. Support for predicate deduction of abs(b)=1 for a=b and abs(a)=1; 3. Support for transitive deduction of non-equivalent relations, for example, a>b b>1 leads to a>1. 4. Deleted useless predicates. But still has something to do in predicate inference: 1. support expr in infer predicate, e.g. abs(t1.c1)>abs(t2.c2) and abs(t1.c1)<1 2. need to add expr qualifier info, to determine whether abs(t1.c1) and abs(t2.c2) is from same table. --- .../org/apache/doris/catalog/OlapTable.java | 1 + .../java/org/apache/doris/catalog/Table.java | 5 + .../org/apache/doris/catalog/TableIf.java | 2 + .../doris/datasource/ExternalTable.java | 6 + .../doris/nereids/jobs/executor/Rewriter.java | 10 +- .../rewrite/InferPredicateByReplace.java | 266 +++++++ .../rules/rewrite/InferPredicates.java | 46 +- .../rules/rewrite/PredicatePropagation.java | 251 ------- .../rules/rewrite/PullUpPredicates.java | 95 ++- .../rules/rewrite/UnequalPredicateInfer.java | 576 +++++++++++++++ .../doris/nereids/trees/expressions/Like.java | 14 +- .../doris/nereids/trees/expressions/Not.java | 10 + .../expressions/StringRegexPredicate.java | 6 +- .../expressions/functions/BoundFunction.java | 6 +- .../trees/expressions/functions/Function.java | 6 +- .../functions/scalar/ScalarFunction.java | 6 +- .../nereids/util/PredicateInferUtils.java | 179 +++++ .../doris/nereids/properties/UniformTest.java | 4 + .../rewrite/InferPredicateByReplaceTest.java | 203 ++++++ .../rewrite/PredicatePropagationTest.java | 67 -- .../rewrite/UnequalPredicateInferTest.java | 688 ++++++++++++++++++ .../org/apache/doris/policy/PolicyTest.java | 12 +- .../data/nereids_hint_tpch_p0/shape/q12.out | 2 +- .../extend_infer_equal_predicate.out | 686 +++++++++++++++++ .../infer_unequal_predicates.out | 165 +++++ .../predicate_infer/infer_predicate.out | 11 +- .../nostats_rf_prune/q12.out | 2 +- .../rf_prune/q12.out | 2 +- .../shape/q12.out | 2 +- .../shape_no_stats/q12.out | 2 +- .../new_shapes_p0/hint_tpch/shape/q12.out | 2 +- .../tpch_sf1000/nostats_rf_prune/q12.out | 2 +- .../tpch_sf1000/rf_prune/q12.out | 2 +- .../new_shapes_p0/tpch_sf1000/shape/q12.out | 2 +- .../tpch_sf1000/shape_no_stats/q12.out | 2 +- .../infer_predicate/infer_predicate.groovy | 14 +- .../extend_infer_equal_predicate.groovy | 357 +++++++++ .../infer_unequal_predicates.groovy | 189 +++++ .../union_all_compensate.groovy | 2 +- .../test_multi_range_partition.groovy | 4 +- 40 files changed, 3501 insertions(+), 406 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInfer.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/util/PredicateInferUtils.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java delete mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInferTest.java create mode 100644 regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out create mode 100644 regression-test/data/nereids_rules_p0/infer_predicate/infer_unequal_predicates.out create mode 100644 regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy create mode 100644 regression-test/suites/nereids_rules_p0/infer_predicate/infer_unequal_predicates.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java index ddbb6f918091c4..e4b61dd4a8c4f9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java @@ -343,6 +343,7 @@ public List getIndexIds() { return indexes.getIndexIds(); } + @Override public TableIndexes getTableIndexes() { return indexes; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java index 234128582fb68f..98cd82902912d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Table.java @@ -640,4 +640,9 @@ public long getCachedRowCount() { public boolean autoAnalyzeEnabled() { return true; } + + @Override + public TableIndexes getTableIndexes() { + return new TableIndexes(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java index ed40840239a3ed..3a688a7b59d17a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/TableIf.java @@ -561,4 +561,6 @@ default boolean isPartitionedTable() { } boolean autoAnalyzeEnabled(); + + TableIndexes getTableIndexes(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/ExternalTable.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/ExternalTable.java index eedbe4e20da312..f0c17da4265095 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/ExternalTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/ExternalTable.java @@ -22,6 +22,7 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.TableAttributes; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.TableIndexes; import org.apache.doris.catalog.constraint.Constraint; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Pair; @@ -357,4 +358,9 @@ protected Optional getSchemaCacheValue() { ExternalSchemaCache cache = Env.getCurrentEnv().getExtMetaCacheMgr().getSchemaCache(catalog); return cache.getSchemaValue(dbName, name); } + + @Override + public TableIndexes getTableIndexes() { + return new TableIndexes(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index d322be75cbb7ca..51c5045aa1f7ac 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -295,8 +295,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // eliminate useless not null or inferred not null // TODO: wait InferPredicates to infer more not null. bottomUp(new EliminateNotNull()), - topDown(new ConvertInnerOrCrossJoin()), - topDown(new ProjectOtherJoinConditionForNestedLoopJoin()) + topDown(new ConvertInnerOrCrossJoin()) ), topic("Set operation optimization", // Do MergeSetOperation first because we hope to match pattern of Distinct SetOperator. @@ -326,7 +325,12 @@ public class Rewriter extends AbstractBatchJobExecutor { // after eliminate outer join, we can move some filters to join.otherJoinConjuncts, // this can help to translate plan to backend topDown(new PushFilterInsideJoin()), - topDown(new FindHashConditionForJoin()) + topDown(new FindHashConditionForJoin()), + // ProjectOtherJoinConditionForNestedLoopJoin will push down the expression + // in the non-equivalent join condition and turn it into slotReference, + // This results in the inability to obtain Cast child information in INFER_PREDICATES, + // which will affect predicate inference with cast. So put this rule behind the INFER_PREDICATES + topDown(new ProjectOtherJoinConditionForNestedLoopJoin()) ), // this rule should invoke after ColumnPruning custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java new file mode 100644 index 00000000000000..d6f4925c7adeb7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplace.java @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.Scope; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.Like; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.ImmutableEqualSet; +import org.apache.doris.nereids.util.PredicateInferUtils; + +import com.google.common.collect.ImmutableList; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/**ReplacePredicate*/ +public class InferPredicateByReplace { + private static List getAllSubExpressions(Expression expr) { + List subExpressions = new ArrayList<>(); + getAllSubExpressions(expr, subExpressions); + return subExpressions; + } + + private static void getAllSubExpressions(Expression expr, List res) { + res.add(expr); + if (expr.children().size() != 1) { + Set slots = expr.getInputSlots(); + if (slots.size() == 1) { + res.add(slots.iterator().next()); + } + return; + } + getAllSubExpressions(expr.child(0), res); + } + + /** fill map exprPredicates : expression and all its corresponding predicates */ + private static class PredicatesCollector extends ExpressionVisitor>> { + public static PredicatesCollector INSTANCE = new PredicatesCollector(); + + @Override + public Void visit(Expression expr, Map> context) { + return null; + } + + @Override + public Void visitOr(Or expr, Map> context) { + return null; + } + + @Override + public Void visitInPredicate(InPredicate inPredicate, Map> context) { + if (!validInPredicate(inPredicate)) { + return null; + } + for (Expression expr : getAllSubExpressions(inPredicate.getCompareExpr())) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(inPredicate); + } + return null; + } + + @Override + public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate, + Map> context) { + if (!validComparisonPredicate(comparisonPredicate)) { + return null; + } + // It is believed that 11 + for (Expression expr : getAllSubExpressions(comparisonPredicate.child(0))) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(comparisonPredicate); + } + return null; + } + + @Override + public Void visitNot(Not not, Map> context) { + if (not.child(0) instanceof InPredicate && validInPredicate((InPredicate) not.child(0)) + || not.child(0) instanceof ComparisonPredicate + && validComparisonPredicate((ComparisonPredicate) not.child(0))) { + for (Expression expr : getAllSubExpressions(not.child(0).child(0))) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(not); + } + } + return null; + } + + @Override + public Void visitLike(Like like, Map> context) { + if (!(like.child(1) instanceof Literal)) { + return null; + } + for (Expression expr : getAllSubExpressions(like.child(0))) { + context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(like); + } + return null; + } + + private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) { + return comparisonPredicate.right() instanceof Literal; + } + + private boolean validInPredicate(InPredicate inPredicate) { + return inPredicate.isLiteralChildren(); + } + } + + /* replaceToThis: find all predicates that replaceToThis can deduce (e.g. replaceToThis = b) + equalSet: the equivalent set of replaceToThis (e.g. equalSet: a=b) + exprPredicates: expression and all its corresponding predicates (e.g. such as {a: [a<10, a>1], b: [b in (1, 2)]}) + return: all predicates that replaceToThis can deduce (return b<10, b>1) */ + private static Set getEqualSetAndDoReplace(T replaceToThis, Set equalSet, + Map> exprPredicates) { + ExpressionAnalyzer analyzer = new ReplaceAnalyzer(null, new Scope(ImmutableList.of()), null, false, false); + Set res = new LinkedHashSet<>(); + for (T equals : equalSet) { + Map replaceMap = new HashMap<>(); + replaceMap.put(equals, replaceToThis); + if (!exprPredicates.containsKey(equals)) { + continue; + } + for (Expression predicate : exprPredicates.get(equals)) { + Expression newPredicates = ExpressionUtils.replace(predicate, replaceMap); + try { + Expression analyzed = analyzer.analyze(newPredicates); + res.add(analyzed.withInferred(true)); + } catch (Exception e) { + // has cast error, just not infer and do nothing + } + } + } + return res; + } + + /* Extract the equivalence relationship a=b, and when case (d_tinyint as int)=d_int is encountered, + remove the cast and extract d_tinyint=d_int + EqualPairs is the output parameter and the equivalent pair of predicate derivation input, + which is used to ensure that the derivation + does not generate repeated equivalent conditions, such as a=b and b=a */ + private static ImmutableEqualSet findEqual(Set inputs) { + ImmutableEqualSet.Builder fromCastEqualSetBuilder = new ImmutableEqualSet.Builder<>(); + for (Expression input : inputs) { + if (!(input instanceof EqualTo)) { + continue; + } + EqualTo equalTo = (EqualTo) input; + Set leftInputSlots = equalTo.left().getInputSlots(); + Set rightInputSlots = equalTo.right().getInputSlots(); + if (leftInputSlots.isEmpty() && rightInputSlots.isEmpty()) { + continue; + } + PredicateInferUtils.getPairFromCast((ComparisonPredicate) input) + .filter(pair -> PredicateInferUtils.isSlotOrLiteral(pair.first) + && PredicateInferUtils.isSlotOrLiteral(pair.second)) + .filter(pair -> !(pair.first instanceof NullLiteral) && !(pair.second instanceof NullLiteral)) + .ifPresent(pair -> { + Expression left = pair.first; + Expression right = pair.second; + fromCastEqualSetBuilder.addEqualPair(left, right); + }); + } + return fromCastEqualSetBuilder.build(); + } + + /** This is the exposed interface. Inputs are the input predicates for derivation. + * The return value is the derived predicates*/ + public static Set infer(Set inputs) { + ImmutableEqualSet hasCastEqualSet = findEqual(inputs); + Set targetExprs = hasCastEqualSet.getAllItemSet(); + if (targetExprs.isEmpty()) { + return new LinkedHashSet<>(inputs); + } + Map> exprPredicates = new HashMap<>(); + for (Expression input : inputs) { + if (input.anyMatch(expr -> !((ExpressionTrait) expr).isDeterministic()) + || input.getInputSlots().size() != 1) { + continue; + } + input.accept(PredicatesCollector.INSTANCE, exprPredicates); + } + Set inferPredicates = new LinkedHashSet<>(inputs); + if (!exprPredicates.isEmpty()) { + for (Expression expr : targetExprs) { + if (expr instanceof Literal) { + continue; + } + inferPredicates.addAll(getEqualSetAndDoReplace(expr, hasCastEqualSet.calEqualSet(expr), + exprPredicates)); + } + } + return inferPredicates; + } + + /** ReplaceAnalyzer is to perform type conversion on the expression after replacement + * and perform type check on the expression. + * If there is a cast that will cause an error during execution, an exception should be thrown. */ + private static class ReplaceAnalyzer extends ExpressionAnalyzer { + private ReplaceAnalyzer(Plan currentPlan, Scope scope, + @Nullable CascadesContext cascadesContext, + boolean enableExactMatch, boolean bindSlotInOuterScope) { + super(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope); + } + + @Override + public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + cast = (Cast) super.visitCast(cast, context); + if (cast.getDataType().isDecimalV3Type()) { + DecimalV3Type targetType = (DecimalV3Type) cast.getDataType(); + DecimalV3Type childType = DecimalV3Type.forType(cast.child().getDataType()); + if ((childType.getPrecision() - childType.getScale()) + > (targetType.getPrecision() - targetType.getScale()) + || childType.getScale() > targetType.getScale()) { + throw new AnalysisException("can not cast from origin type " + cast.child().getDataType() + + " to target type=" + targetType); + } + } else if (cast.getDataType().isDecimalV2Type()) { + DecimalV2Type targetType = (DecimalV2Type) cast.getDataType(); + DecimalV2Type childType = DecimalV2Type.forType(cast.child().getDataType()); + if ((childType.getPrecision() - childType.getScale()) + > (targetType.getPrecision() - targetType.getScale()) + || childType.getScale() > targetType.getScale()) { + throw new AnalysisException("can not cast from origin type " + cast.child().getDataType() + + " to target type=" + targetType); + } + } + return cast; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java index 5256c7744b9837..98fd368b30e076 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferPredicates.java @@ -17,9 +17,11 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.mysql.MysqlCommand; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalExcept; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; @@ -29,16 +31,18 @@ import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PredicateInferUtils; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * infer additional predicates for `LogicalFilter` and `LogicalJoin`. @@ -58,10 +62,17 @@ * */ public class InferPredicates extends DefaultPlanRewriter implements CustomRewriter { - private final PullUpPredicates pollUpPredicates = new PullUpPredicates(); + private final PullUpPredicates pullUpPredicates = new PullUpPredicates(false); + // The role of pullUpAllPredicates is to prevent inference of redundant predicates + private final PullUpPredicates pullUpAllPredicates = new PullUpPredicates(true); @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { + // Preparing stmt requires that the predicate cannot be changed, so no predicate inference is performed. + ConnectContext connectContext = jobContext.getCascadesContext().getConnectContext(); + if (connectContext != null && connectContext.getCommand() == MysqlCommand.COM_STMT_PREPARE) { + return plan; + } return plan.accept(this, jobContext); } @@ -104,13 +115,8 @@ public Plan visitLogicalJoin(LogicalJoin join, J public Plan visitLogicalFilter(LogicalFilter filter, JobContext context) { filter = visitChildren(this, filter, context); Set filterPredicates = pullUpPredicates(filter); - filterPredicates.removeAll(pullUpPredicates(filter.child())); - filter.getConjuncts().forEach(filterPredicates::remove); - if (!filterPredicates.isEmpty()) { - filterPredicates.addAll(filter.getConjuncts()); - return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child()); - } - return filter; + filterPredicates.removeAll(pullUpAllPredicates(filter.child())); + return new LogicalFilter<>(ImmutableSet.copyOf(filterPredicates), filter.child()); } @Override @@ -156,19 +162,27 @@ private Set getAllExpressions(Plan left, Plan right, Optional baseExpressions = pullUpPredicates(left); baseExpressions.addAll(pullUpPredicates(right)); condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on))); - baseExpressions.addAll(PredicatePropagation.infer(baseExpressions)); - return baseExpressions; + return PredicateInferUtils.inferPredicate(baseExpressions); } private Set pullUpPredicates(Plan plan) { - return Sets.newHashSet(plan.accept(pollUpPredicates, null)); + return Sets.newLinkedHashSet(plan.accept(pullUpPredicates, null)); + } + + private Set pullUpAllPredicates(Plan plan) { + return Sets.newLinkedHashSet(plan.accept(pullUpAllPredicates, null)); } private Plan inferNewPredicate(Plan plan, Set expressions) { - Set predicates = expressions.stream() - .filter(c -> !c.getInputSlots().isEmpty() && plan.getOutputSet().containsAll(c.getInputSlots())) - .collect(Collectors.toSet()); - predicates.removeAll(plan.accept(pollUpPredicates, null)); + Set predicates = new LinkedHashSet<>(); + Set planOutputs = plan.getOutputSet(); + for (Expression expr : expressions) { + Set slots = expr.getInputSlots(); + if (!slots.isEmpty() && planOutputs.containsAll(slots)) { + predicates.add(expr); + } + } + predicates.removeAll(plan.accept(pullUpAllPredicates, null)); return PlanUtils.filterOrSelf(predicates, plan); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java deleted file mode 100644 index d1eba6cce36157..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ /dev/null @@ -1,251 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.InPredicate; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DateTimeType; -import org.apache.doris.nereids.types.DateTimeV2Type; -import org.apache.doris.nereids.types.DateType; -import org.apache.doris.nereids.types.DateV2Type; -import org.apache.doris.nereids.types.coercion.CharacterType; -import org.apache.doris.nereids.types.coercion.DateLikeType; -import org.apache.doris.nereids.types.coercion.IntegralType; -import org.apache.doris.nereids.util.ImmutableEqualSet; -import org.apache.doris.nereids.util.TypeCoercionUtils; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * derive additional predicates. - * for example: - * a = b and a = 1 => b = 1 - */ -public class PredicatePropagation { - - private enum InferType { - NONE(null), - INTEGRAL(IntegralType.class), - STRING(CharacterType.class), - DATE(DateLikeType.class), - OTHER(DataType.class); - - private final Class superClazz; - - InferType(Class superClazz) { - this.superClazz = superClazz; - } - } - - /** - * infer additional predicates. - */ - public static Set infer(Set predicates) { - ImmutableEqualSet.Builder equalSetBuilder = new ImmutableEqualSet.Builder<>(); - Map> slotPredicates = new HashMap<>(); - Set> equalPairs = new HashSet<>(); - for (Expression predicate : predicates) { - Set inputSlots = predicate.getInputSlots(); - if (inputSlots.size() == 1) { - if (predicate instanceof ComparisonPredicate - || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren())) { - slotPredicates.computeIfAbsent(inputSlots.iterator().next(), k -> new ArrayList<>()).add(predicate); - } - continue; - } - - if (predicate instanceof EqualTo) { - getEqualSlot(equalSetBuilder, equalPairs, (EqualTo) predicate); - } - } - - ImmutableEqualSet equalSet = equalSetBuilder.build(); - - Set inferred = new HashSet<>(); - slotPredicates.forEach((left, exprs) -> { - for (Slot right : equalSet.calEqualSet(left)) { - for (Expression expr : exprs) { - Expression inferPredicate = doInferPredicate(left, right, expr); - if (inferPredicate != null) { - inferred.add(inferPredicate); - } - } - } - }); - - // infer equal to equal like a = b & b = c -> a = c - // a b c | e f g - // get (a b) (a c) (b c) | (e f) (e g) (f g) - List> equalSetList = equalSet.calEqualSetList(); - for (Set es : equalSetList) { - List el = es.stream().sorted(Comparator.comparingInt(s -> s.getExprId().asInt())) - .collect(Collectors.toList()); - for (int i = 0; i < el.size(); i++) { - Slot left = el.get(i); - for (int j = i + 1; j < el.size(); j++) { - Slot right = el.get(j); - if (!equalPairs.contains(Pair.of(left, right))) { - inferred.add(TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right)) - .withInferred(true)); - } - } - } - } - - return inferred; - } - - private static Expression doInferPredicate(Expression equalLeft, Expression equalRight, Expression predicate) { - DataType leftType = predicate.child(0).getDataType(); - InferType inferType; - if (leftType instanceof CharacterType) { - inferType = InferType.STRING; - } else if (leftType instanceof IntegralType) { - inferType = InferType.INTEGRAL; - } else if (leftType instanceof DateLikeType) { - inferType = InferType.DATE; - } else { - inferType = InferType.OTHER; - } - if (predicate instanceof ComparisonPredicate) { - ComparisonPredicate comparisonPredicate = (ComparisonPredicate) predicate; - Optional left = validForInfer(comparisonPredicate.left(), inferType); - Optional right = validForInfer(comparisonPredicate.right(), inferType); - if (!left.isPresent() || !right.isPresent()) { - return null; - } - } else if (predicate instanceof InPredicate) { - InPredicate inPredicate = (InPredicate) predicate; - Optional left = validForInfer(inPredicate.getCompareExpr(), inferType); - if (!left.isPresent()) { - return null; - } - } - - Expression newPredicate = predicate.rewriteUp(e -> { - if (e.equals(equalLeft)) { - return equalRight; - } else if (e.equals(equalRight)) { - return equalLeft; - } else { - return e; - } - }); - if (predicate instanceof ComparisonPredicate) { - return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) newPredicate).withInferred(true); - } else { - return TypeCoercionUtils.processInPredicate((InPredicate) newPredicate).withInferred(true); - } - } - - private static Optional validForInfer(Expression expression, InferType inferType) { - if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { - return Optional.empty(); - } - if (expression instanceof SlotReference || expression.isConstant()) { - return Optional.of(expression); - } - if (!(expression instanceof Cast)) { - return Optional.empty(); - } - Cast cast = (Cast) expression; - Expression child = cast.child(); - DataType dataType = cast.getDataType(); - DataType childType = child.getDataType(); - if (inferType == InferType.INTEGRAL) { - // avoid cast from wider type to narrower type, such as cast(int as smallint) - // IntegralType dataType = (IntegralType) expression.getDataType(); - // DataType childType = ((Cast) expression).child().getDataType(); - // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { - // return validForInfer(((Cast) expression).child(), inferType); - // } - return validForInfer(child, inferType); - } else if (inferType == InferType.DATE) { - // avoid lost precision - if (dataType instanceof DateType) { - if (childType instanceof DateV2Type || childType instanceof DateType) { - return validForInfer(child, inferType); - } - } else if (dataType instanceof DateV2Type) { - if (childType instanceof DateType || childType instanceof DateV2Type) { - return validForInfer(child, inferType); - } - } else if (dataType instanceof DateTimeType) { - if (!(childType instanceof DateTimeV2Type)) { - return validForInfer(child, inferType); - } - } else if (dataType instanceof DateTimeV2Type) { - return validForInfer(child, inferType); - } - } else if (inferType == InferType.STRING) { - // avoid substring cast such as cast(char(3) as char(2)) - if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { - return validForInfer(child, inferType); - } - } - return Optional.empty(); - } - - private static Optional> inferInferInfo(ComparisonPredicate comparisonPredicate) { - DataType leftType = comparisonPredicate.left().getDataType(); - InferType inferType; - if (leftType instanceof CharacterType) { - inferType = InferType.STRING; - } else if (leftType instanceof IntegralType) { - inferType = InferType.INTEGRAL; - } else if (leftType instanceof DateLikeType) { - inferType = InferType.DATE; - } else { - inferType = InferType.OTHER; - } - Optional left = validForInfer(comparisonPredicate.left(), inferType); - Optional right = validForInfer(comparisonPredicate.right(), inferType); - if (!left.isPresent() || !right.isPresent()) { - return Optional.empty(); - } - return Optional.of(Pair.of(left.get(), right.get())); - } - - private static void getEqualSlot(ImmutableEqualSet.Builder equalSlots, Set> equalPairs, - EqualTo predicate) { - inferInferInfo(predicate) - .filter(info -> info.first instanceof Slot && info.second instanceof Slot) - .ifPresent(pair -> { - Slot left = (Slot) pair.first; - Slot right = (Slot) pair.second; - equalSlots.addEqualPair(left, right); - equalPairs.add(left.getExprId().asInt() <= right.getExprId().asInt() - ? Pair.of(left, right) : Pair.of(right, left)); - }); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 8082c0624a6047..a6d5cddfd08c61 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -26,7 +26,6 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; -import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalExcept; @@ -38,16 +37,17 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.PredicateInferUtils; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import java.util.HashMap; -import java.util.HashSet; import java.util.IdentityHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; @@ -60,6 +60,11 @@ public class PullUpPredicates extends PlanVisitor, Void> { Map> cache = new IdentityHashMap<>(); + private final boolean getAllPredicates; + + public PullUpPredicates(boolean all) { + getAllPredicates = all; + } @Override public ImmutableSet visit(Plan plan, Void context) { @@ -71,19 +76,21 @@ public ImmutableSet visit(Plan plan, Void context) { @Override public ImmutableSet visitLogicalOneRowRelation(LogicalOneRowRelation r, Void context) { - ImmutableSet.Builder predicates = ImmutableSet.builder(); - for (NamedExpression expr : r.getProjects()) { - if (expr instanceof Alias && expr.child(0) instanceof Literal) { - predicates.add(new EqualTo(expr.toSlot(), expr.child(0))); + return cacheOrElse(r, () -> { + Set predicates = new LinkedHashSet<>(); + for (NamedExpression expr : r.getProjects()) { + if (expr instanceof Alias && expr.child(0) instanceof Literal) { + predicates.add(new EqualTo(expr.toSlot(), expr.child(0))); + } } - } - return predicates.build(); + return ImmutableSet.copyOf(predicates); + }); } @Override public ImmutableSet visitLogicalIntersect(LogicalIntersect intersect, Void context) { return cacheOrElse(intersect, () -> { - ImmutableSet.Builder builder = ImmutableSet.builder(); + Set predicates = new LinkedHashSet<>(); for (int i = 0; i < intersect.children().size(); ++i) { Plan child = intersect.child(i); Set childFilters = child.accept(this, context); @@ -95,9 +102,9 @@ public ImmutableSet visitLogicalIntersect(LogicalIntersect intersect NamedExpression output = intersect.getOutput().get(j); replaceMap.put(intersect.getRegularChildOutput(i).get(j), output); } - builder.addAll(ExpressionUtils.replace(childFilters, replaceMap)); + predicates.addAll(ExpressionUtils.replace(childFilters, replaceMap)); } - return getAvailableExpressions(builder.build(), intersect); + return getAvailableExpressions(ImmutableSet.copyOf(predicates), intersect); }); } @@ -128,7 +135,7 @@ public ImmutableSet visitLogicalUnion(LogicalUnion union, Void conte } else if (union.getConstantExprsList().isEmpty() && union.arity() != 0) { return getFiltersFromUnionChild(union, context); } else if (!union.getConstantExprsList().isEmpty() && union.arity() != 0) { - HashSet fromChildFilters = new HashSet<>(getFiltersFromUnionChild(union, context)); + Set fromChildFilters = new LinkedHashSet<>(getFiltersFromUnionChild(union, context)); if (fromChildFilters.isEmpty()) { return ImmutableSet.of(); } @@ -153,14 +160,35 @@ public ImmutableSet visitLogicalFilter(LogicalFilter @Override public ImmutableSet visitLogicalJoin(LogicalJoin join, Void context) { return cacheOrElse(join, () -> { - Set predicates = Sets.newHashSet(); - ImmutableSet leftPredicates = join.left().accept(this, context); - ImmutableSet rightPredicates = join.right().accept(this, context); - predicates.addAll(leftPredicates); - predicates.addAll(rightPredicates); - if (join.getJoinType() == JoinType.CROSS_JOIN || join.getJoinType() == JoinType.INNER_JOIN) { - predicates.addAll(join.getHashJoinConjuncts()); - predicates.addAll(join.getOtherJoinConjuncts()); + Set predicates = new LinkedHashSet<>(); + Supplier> leftPredicates = Suppliers.memoize( + () -> join.left().accept(this, context)); + Supplier> rightPredicates = Suppliers.memoize( + () -> join.right().accept(this, context)); + switch (join.getJoinType()) { + case CROSS_JOIN: + case INNER_JOIN: { + predicates.addAll(leftPredicates.get()); + predicates.addAll(rightPredicates.get()); + predicates.addAll(join.getHashJoinConjuncts()); + predicates.addAll(join.getOtherJoinConjuncts()); + break; + } + case LEFT_OUTER_JOIN: + case LEFT_SEMI_JOIN: + case LEFT_ANTI_JOIN: + case NULL_AWARE_LEFT_ANTI_JOIN: { + predicates.addAll(leftPredicates.get()); + break; + } + case RIGHT_OUTER_JOIN: + case RIGHT_SEMI_JOIN: + case RIGHT_ANTI_JOIN: { + predicates.addAll(rightPredicates.get()); + break; + } + default: + break; } return getAvailableExpressions(predicates, join); }); @@ -226,22 +254,21 @@ private ImmutableSet getAvailableExpressions(Set predica if (predicates.isEmpty()) { return ImmutableSet.of(); } - Set inferPredicates = PredicatePropagation.infer(predicates); - Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10); - Set outputSet = plan.getOutputSet(); - - for (Expression predicate : predicates) { - if (outputSet.containsAll(predicate.getInputSlots())) { - newPredicates.add(predicate); - } + Set inferPredicates = new LinkedHashSet<>(); + if (getAllPredicates) { + inferPredicates.addAll(PredicateInferUtils.inferAllPredicate(predicates)); + } else { + inferPredicates.addAll(PredicateInferUtils.inferPredicate(predicates)); } + Set newPredicates = new LinkedHashSet<>(inferPredicates.size()); + Set outputSet = plan.getOutputSet(); for (Expression inferPredicate : inferPredicates) { if (outputSet.containsAll(inferPredicate.getInputSlots())) { newPredicates.add(inferPredicate); } } - return newPredicates.build(); + return ImmutableSet.copyOf(newPredicates); } private boolean hasAgg(Expression expression) { @@ -249,7 +276,7 @@ private boolean hasAgg(Expression expression) { } private ImmutableSet getFiltersFromUnionChild(LogicalUnion union, Void context) { - Set filters = new HashSet<>(); + Set filters = new LinkedHashSet<>(); for (int i = 0; i < union.getArity(); ++i) { Plan child = union.child(i); Set childFilters = child.accept(this, context); @@ -276,10 +303,10 @@ private ImmutableSet getFiltersFromUnionChild(LogicalUnion union, Vo private ImmutableSet getFiltersFromUnionConstExprs(LogicalUnion union) { List> constExprs = union.getConstantExprsList(); - ImmutableSet.Builder filtersFromConstExprs = ImmutableSet.builder(); + Set filtersFromConstExprs = new LinkedHashSet<>(); for (int col = 0; col < union.getOutput().size(); ++col) { Expression compareExpr = union.getOutput().get(col); - Set options = new HashSet<>(); + Set options = new LinkedHashSet<>(); for (List constExpr : constExprs) { if (constExpr.get(col) instanceof Alias && ((Alias) constExpr.get(col)).child() instanceof Literal) { @@ -296,6 +323,6 @@ private ImmutableSet getFiltersFromUnionConstExprs(LogicalUnion unio filtersFromConstExprs.add(new EqualTo(compareExpr, options.iterator().next())); } } - return filtersFromConstExprs.build(); + return ImmutableSet.copyOf(filtersFromConstExprs); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInfer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInfer.java new file mode 100644 index 00000000000000..83209d6691c53e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInfer.java @@ -0,0 +1,576 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.TableIf; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.util.PredicateInferUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * this class do these things: + * {@code + * 1. t1.a=t2.b t2.b=t3.c -> t1.a=t2.b t2.b=t3.c (reserve all three condition) + * 2. remove useless equal predicates(e.g. t1.a=t1.b t1.a=1 t1.b=1 -> t1.a=1 t1.b=1. t1.a=t1.b is removed) + * 3. do unequalPredicateInfer(e.g. t1.a t1.a<1 and t1.a t1.a pair; + private final Relation relation; + + private PairAndRelation(Pair p, Relation r) { + pair = p; + relation = r; + } + } + + // Save and infer the relationship between inputExpressions + private final Relation[][] graph; + // slots or literal at both ends of the input predicate, and its index corresponds to the one in the graph. + private final List usedExprs = new ArrayList<>(); + // predicates used in derivation, this is used in chooseInputPredicates + private final List usedPredicates = new ArrayList<>(); + // usedPredicatesPairs has same length with usedPredicates, + // usedPredicatesPairs[i] and usedPredicates[i] correspond to same predicates + // usedPredicatesPairs is extracted from cast and used in graph + private final List usedPredicatesPairs = new ArrayList<>(); + // Elements and their indexes in usedExprs + private final Map usedExprPosition = new HashMap<>(); + // size of usedExprs + private final int size; + // not use input predicates + private final List otherPredicates = new ArrayList<>(); + + /**Constructor*/ + public InferenceGraph(Set inputs) { + Set inputExpressionSet = new HashSet<>(); + for (Expression input : inputs) { + if (!(input instanceof ComparisonPredicate)) { + otherPredicates.add(input); + continue; + } + ComparisonPredicate comparison = (ComparisonPredicate) input; + if (comparison.left().equals(comparison.right())) { + otherPredicates.add(comparison); + continue; + } + if (comparison.left() instanceof NullLiteral || comparison.right() instanceof NullLiteral) { + otherPredicates.add(comparison); + continue; + } + Set leftSlots = comparison.left().getInputSlots(); + Set rightSlots = comparison.right().getInputSlots(); + if (leftSlots.isEmpty() && rightSlots.isEmpty()) { + otherPredicates.add(comparison); + continue; + } + ComparisonPredicate commute; + if (comparison instanceof LessThan || comparison instanceof LessThanEqual) { + commute = (ComparisonPredicate) comparison.commute().withInferred(comparison.isInferred()); + } else if (comparison instanceof GreaterThan || comparison instanceof GreaterThanEqual + || comparison instanceof EqualTo) { + commute = comparison; + } else { + otherPredicates.add(comparison); + continue; + } + Optional> optionalPair = PredicateInferUtils.getPairFromCast(commute); + if (!optionalPair.isPresent()) { + otherPredicates.add(comparison); + continue; + } + Pair pair = optionalPair.get(); + if (!PredicateInferUtils.isSlotOrLiteral(pair.first) + || !PredicateInferUtils.isSlotOrLiteral(pair.second)) { + otherPredicates.add(comparison); + continue; + } + inputExpressionSet.add(pair.first); + inputExpressionSet.add(pair.second); + usedPredicates.add(comparison); + usedPredicatesPairs.add(new PairAndRelation(pair, getType(commute))); + } + usedExprs.addAll(inputExpressionSet); + // Sorting is required to ensure the stability of the plan shape + // and to ensure that the same results are output in the derivation of d>1 d=c and c>1 d=c + usedExprs.sort(Comparator.comparing(ExpressionTrait::toSql)); + size = usedExprs.size(); + for (int i = 0; i < size; ++i) { + usedExprPosition.put(usedExprs.get(i), i); + } + graph = new Relation[size][size]; + initGraph(graph); + // Add edges to the graph. + for (PairAndRelation predicatesPair : usedPredicatesPairs) { + int l = usedExprPosition.get(predicatesPair.pair.first); + int r = usedExprPosition.get(predicatesPair.pair.second); + set(graph, l, r, predicatesPair.relation); + } + } + + public void initGraph(Relation[][] g) { + for (int i = 0; i < size; ++i) { + for (int j = 0; j < size; ++j) { + g[i][j] = Relation.UNDEFINED; + } + } + } + + private void connect(Relation[][] graph, int left, int right, int mid) { + if (graph[left][right] != Relation.EQ) { + if (graph[left][mid] == Relation.EQ && graph[mid][right] == Relation.EQ) { + graph[left][right] = Relation.EQ; + } + } + if (graph[left][right] != Relation.GTE) { + if (graph[left][mid] == Relation.GTE && graph[mid][right] == Relation.EQ + || graph[left][mid] == Relation.EQ && graph[mid][right] == Relation.GTE) { + graph[left][right] = Relation.GTE; + } + } + if (graph[left][right] != Relation.GT) { + if (graph[left][mid] == Relation.GT && graph[mid][right] != Relation.UNDEFINED + || graph[left][mid] != Relation.UNDEFINED && graph[mid][right] == Relation.GT) { + graph[left][right] = Relation.GT; + } + } + } + + // Calculate the relationship between left and right derived from mid + private Relation connectInThisPath(final Relation[][] graph, int left, int right, int mid) { + Relation deduceRelation = Relation.UNDEFINED; + if (graph[left][mid] == Relation.EQ && graph[mid][right] == Relation.EQ) { + deduceRelation = Relation.EQ; + } + if (graph[left][mid] == Relation.GTE && graph[mid][right] == Relation.EQ + || graph[left][mid] == Relation.EQ && graph[mid][right] == Relation.GTE) { + deduceRelation = Relation.GTE; + } + if (graph[left][mid] == Relation.GT && graph[mid][right] != Relation.UNDEFINED + || graph[left][mid] != Relation.UNDEFINED && graph[mid][right] == Relation.GT) { + deduceRelation = Relation.GT; + } + return deduceRelation; + } + + /** use Floyd algorithm to deduce the inequality */ + public void deduce(Relation[][] graph) { + for (int mid = 0; mid < size; ++mid) { + for (int left = 0; left < size; ++left) { + for (int right = 0; right < size; ++right) { + connect(graph, left, right, mid); + } + } + } + } + + /**topoSort*/ + public List topoSort() { + ArrayList order = new ArrayList<>(); + order.ensureCapacity(size); + ArrayList visited = new ArrayList<>(); + visited.ensureCapacity(size); + for (int i = 0; i < size; ++i) { + visited.add(false); + } + for (int i = 0; i < size; ++i) { + dfs(i, visited, order); + } + return order; + } + + private void dfs(int node, List visited, List order) { + if (visited.get(node)) { + return; + } + visited.set(node, true); + for (int i = 0; i < size; ++i) { + if (graph[node][i] == Relation.GT || graph[node][i] == Relation.GTE) { + dfs(i, visited, order); + } + } + order.add(node); + } + + /**Determine whether the slots in a predicate come from only one table*/ + private boolean isTableFilter(int left, int right) { + Set qualifiers = new HashSet<>(); + for (Slot slot : usedExprs.get(left).getInputSlots()) { + qualifiers.add(String.join(".", slot.getQualifier())); + } + for (Slot slot : usedExprs.get(right).getInputSlots()) { + qualifiers.add(String.join(".", slot.getQualifier())); + } + // TODO: + // isTableFilter(abs(t1.a)#1 = abs(t1.b)#2) will return true + // isTableFilter(abs(t1.a)#1 = abs(t2.b)#2) will also return true, which is wrong. + // because expr(e.g. abs(a) #1) qualifiers is empty. + // We cannot distinguish whether abs(t1.a)#1 = abs(t2.b)#2 is a TableFilter or not. + // current code may lead to some useful predicates be removed + return qualifiers.size() == 1; + } + + private boolean hasIndexOrPartitionColumn(Expression left, Expression right) { + SlotReference checkSlot; + if (left instanceof SlotReference && right instanceof Literal) { + checkSlot = (SlotReference) left; + } else if (left instanceof Literal && right instanceof SlotReference) { + checkSlot = (SlotReference) right; + } else { + return false; + } + if (!checkSlot.isColumnFromTable()) { + return false; + } + Column column = checkSlot.getColumn().get(); + if (column.isKey()) { + return true; + } + if (!checkSlot.getTable().isPresent()) { + return false; + } + TableIf tableIf = checkSlot.getTable().get(); + if (tableIf.isPartitionedTable() && tableIf.isPartitionColumn(column.getName())) { + return true; + } + /* Indexes are seldom used and are not supported temporarily + if (tableIf.getType() != TableType.OLAP) { + return false; + } + TableIndexes tableIndexes = tableIf.getTableIndexes(); + for (Index index : tableIndexes.getIndexes()) { + IndexDef.IndexType type = index.getIndexType(); + if (type == IndexType.NGRAM_BF || type == IndexType.BLOOMFILTER) { + continue; + } + Set columns = new HashSet<>(index.getColumns()); + if (columns.contains(column.getName())) { + return true; + } + }*/ + return false; + } + + // determine whether the comparison predicate of type between left right can be deduced by mid + private boolean checkDeducible(final Relation[][] graph, int left, int right, int mid, Relation type) { + Relation deduceType = connectInThisPath(graph, left, right, mid); + return deduceType == type; + } + + private List removeExprEqualToConstant(List order, Set equalWithConstant) { + // Remove expr equal to constant + List orderToInfer = new ArrayList<>(); + for (Integer integer : order) { + if (equalWithConstant.contains(integer)) { + continue; + } + orderToInfer.add(integer); + } + return orderToInfer; + } + + /**chooseUnequalPredicates*/ + public void chooseUnequalPredicates(Relation[][] chosen, Set equalWithConstant) { + List order = topoSort(); + List orderToInfer = removeExprEqualToConstant(order, equalWithConstant); + //Select predicate: + // 1. Do not select predicates that can be deduced from the intermediate expr + // 2. If it is an index column or partition column, reserve the predicate + for (int i = 1; i < orderToInfer.size(); ++i) { + for (int j = 0; j < i; ++j) { + int left = orderToInfer.get(i); + int right = orderToInfer.get(j); + if (graph[left][right] == Relation.EQ || graph[left][right] == Relation.UNDEFINED) { + continue; + } + if (!isTableFilter(left, right)) { + continue; + } + boolean skip = hasIndexOrPartitionColumn(usedExprs.get(left), usedExprs.get(right)); + boolean deducible = false; + for (int m = j + 1; !skip && !deducible && m < i; ++m) { + int mid = orderToInfer.get(m); + if (usedExprs.get(mid) instanceof Literal) { + deducible = checkDeducible(graph, left, right, mid, graph[left][right]); + } else if (isTableFilter(left, mid) && isTableFilter(right, mid)) { + deducible = checkDeducible(graph, left, right, mid, graph[left][right]); + } + } + if (!deducible) { + set(chosen, left, right, graph[left][right]); + } + } + } + } + + private Set generatePredicates(Relation[][] chosen) { + Set newPredicates = new LinkedHashSet<>(); + for (int i = 0; i < size; ++i) { + for (int j = 0; j < size; ++j) { + if (i == j || isAllLiteral(i, j)) { + continue; + } + try { + if (chosen[i][j] == Relation.GT) { + newPredicates.add(normalize(new GreaterThan(usedExprs.get(i), usedExprs.get(j)))); + } else if (chosen[i][j] == Relation.GTE) { + newPredicates.add(normalize(new GreaterThanEqual(usedExprs.get(i), usedExprs.get(j)))); + } else if (chosen[i][j] == Relation.EQ) { + newPredicates.add(normalize(new EqualTo(usedExprs.get(i), usedExprs.get(j)))); + clear(chosen, i, j, Relation.EQ); + } + } catch (AnalysisException e) { + // type error, just not generate this predicate, do nothing but continue + } + } + } + return newPredicates; + } + + private ComparisonPredicate normalizePredicate(ComparisonPredicate expr) { + return expr.left().isConstant() && !expr.right().isConstant() ? expr.commute() : expr; + } + + private Relation getType(ComparisonPredicate comparisonPredicate) { + if (comparisonPredicate instanceof GreaterThan) { + return Relation.GT; + } else if (comparisonPredicate instanceof GreaterThanEqual) { + return Relation.GTE; + } else if (comparisonPredicate instanceof EqualTo) { + return Relation.EQ; + } + return Relation.UNDEFINED; + } + + private void clear(Relation[][] graph, int left, int right, Relation type) { + graph[left][right] = Relation.UNDEFINED; + if (type == Relation.EQ) { + graph[right][left] = Relation.UNDEFINED; + } + } + + private void set(Relation[][] graph, int left, int right, Relation type) { + graph[left][right] = type; + if (type == Relation.EQ) { + graph[right][left] = type; + } + } + + // A new edge from hub1 to hub2 has been added to the graph. + // Use this edge to extend the connectivity between the graph nodes + private void expandGraph(Relation[][] graph, int hub1, int hub2) { + //Update the path from all nodes to hub2 (use hub1->hub2) + for (int left = 0; left < size; ++left) { + connect(graph, left, hub2, hub1); + } + // Use hub2 as the transit node to update the path between any two nodes + for (int l = 0; l < size; ++l) { + for (int r = 0; r < size; ++r) { + connect(graph, l, r, hub2); + } + } + } + + /**chooseInputPredicates*/ + public Set chooseInputPredicates(Relation[][] chosen) { + boolean[] keep = new boolean[usedPredicates.size()]; + Relation[][] deduced = new Relation[size][size]; + for (int i = 0; i < size; ++i) { + for (int j = 0; j < size; ++j) { + deduced[i][j] = chosen[i][j]; + if (i == j) { + deduced[i][j] = Relation.EQ; + } + } + } + deduce(deduced); + // If an input predicate is not chosen and can be deduced by chosen, + // then the input predicate need not be retained (because it is a useless predicate) + // And the predicates in inputs that cannot be deduced by chosen should be retained. + for (int i = 0; i < usedPredicates.size(); ++i) { + Relation type = usedPredicatesPairs.get(i).relation; + int left = usedExprPosition.get(usedPredicatesPairs.get(i).pair.first); + int right = usedExprPosition.get(usedPredicatesPairs.get(i).pair.second); + if (chosen[left][right] == type) { + keep[i] = true; + clear(chosen, left, right, type); + } else if (deduced[left][right] != type) { + keep[i] = true; + set(deduced, left, right, Relation.EQ); + expandGraph(deduced, left, right); + if (type == Relation.EQ) { + expandGraph(deduced, right, left); + } + } + } + Set chooseInputs = new LinkedHashSet<>(); + for (int i = 0; i < usedPredicates.size(); ++i) { + if (!keep[i]) { + continue; + } + chooseInputs.add(normalizePredicate(usedPredicates.get(i)) + .withInferred(usedPredicates.get(i).isInferred())); + } + return chooseInputs; + } + + /**chooseEqualPredicates*/ + public Relation[][] chooseEqualPredicates(Set equalWithConstant) { + Relation[][] chosen = new Relation[size][size]; + initGraph(chosen); + int[] equalToLiteral = new int[size]; + Arrays.fill(equalToLiteral, -1); + // save equal predicates like a=b (no literal) + List> tableFilters = new ArrayList<>(); + // save equal predicates like t1.a=t2.b (no literal) + List> nonTableFilters = new ArrayList<>(); + for (int i = 0; i < size; ++i) { + for (int j = i + 1; j < size; ++j) { + if (graph[i][j] != Relation.EQ) { + continue; + } + // choose predicate with one side literal or t1.a=t2.b(not table filter equal) + if (usedExprs.get(i) instanceof Literal && usedExprs.get(j) instanceof Literal) { + continue; + } else if (!(usedExprs.get(i) instanceof Literal) && !(usedExprs.get(j) instanceof Literal)) { + if (isTableFilter(i, j)) { + tableFilters.add(Pair.of(i, j)); + } else { + nonTableFilters.add(Pair.of(i, j)); + } + } else if (usedExprs.get(i) instanceof Literal + || usedExprs.get(j) instanceof Literal) { + set(chosen, i, j, Relation.EQ); + if (usedExprs.get(i) instanceof Literal) { + equalToLiteral[j] = i; + equalWithConstant.add(j); + } else { + equalToLiteral[i] = j; + equalWithConstant.add(i); + } + } + } + } + // a=b a=c a=1 only infer a=1 b=1 c=1, not retain a=b a=c + for (Pair tableFilter : tableFilters) { + int left = tableFilter.first; + int right = tableFilter.second; + if (equalToLiteral[left] == -1 || equalToLiteral[right] == -1) { + set(chosen, left, right, Relation.EQ); + equalToLiteral[left] = left; + equalToLiteral[right] = left; + } + } + for (Pair nonTableFilter : nonTableFilters) { + int left = nonTableFilter.first; + int right = nonTableFilter.second; + if (!equalWithConstant.contains(left) && !equalWithConstant.contains(right)) { + set(chosen, left, right, Relation.EQ); + } + } + return chosen; + } + + private Expression normalize(ComparisonPredicate cmp) { + return TypeCoercionUtils.processComparisonPredicate(normalizePredicate(cmp)).withInferred(true); + } + + private boolean isAllLiteral(int i, int j) { + Expression left = usedExprs.get(i); + Expression right = usedExprs.get(j); + return left instanceof Literal && right instanceof Literal; + } + + /** for test */ + public Relation[][] getGraph() { + return graph; + } + } + + /**inferUnequalPredicates*/ + public static Set inferUnequalPredicates(Set inputs) { + if (inputs.size() < 2) { + return inputs; + } + InferenceGraph inferGraph = new InferenceGraph(inputs); + if (inferGraph.usedExprs.isEmpty()) { + return inputs; + } + inferGraph.deduce(inferGraph.graph); + Set equalWithConstant = new HashSet<>(); + InferenceGraph.Relation[][] chosen = inferGraph.chooseEqualPredicates(equalWithConstant); + inferGraph.chooseUnequalPredicates(chosen, equalWithConstant); + Set newPredicates = inferGraph.chooseInputPredicates(chosen); + newPredicates.addAll(inferGraph.generatePredicates(chosen)); + newPredicates.addAll(inferGraph.otherPredicates); + return newPredicates; + } + + /** deduce predicates and generate all predicates without choosing*/ + public static Set inferAllPredicates(Set inputs) { + if (inputs.size() < 2) { + return inputs; + } + InferenceGraph inferGraph = new InferenceGraph(inputs); + if (inferGraph.usedExprs.isEmpty()) { + return inputs; + } + inferGraph.deduce(inferGraph.graph); + Set newPredicates = new LinkedHashSet<>(); + newPredicates.addAll(inferGraph.generatePredicates(inferGraph.graph)); + newPredicates.addAll(inferGraph.otherPredicates); + return newPredicates; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Like.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Like.java index 89a9c7797152d6..84b6ffa984fff4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Like.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Like.java @@ -28,13 +28,16 @@ * like expression: a like 'xxx%'. */ public class Like extends StringRegexPredicate { - public Like(Expression left, Expression right) { - super("like", ImmutableList.of(left, right)); + this(ImmutableList.of(left, right)); } private Like(List children) { - super("like", children); + this(children, false); + } + + private Like(List children, boolean inferred) { + super("like", children, inferred); } @Override @@ -46,4 +49,9 @@ public Like withChildren(List children) { public R accept(ExpressionVisitor visitor, C context) { return visitor.visitLike(this, context); } + + @Override + public Expression withInferred(boolean inferred) { + return new Like(this.children, inferred); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java index 44197ae617d276..5061cab5ac9631 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java @@ -44,6 +44,11 @@ public Not(Expression child) { this(child, false); } + public Not(List child, boolean isGeneratedIsNotNull, boolean inferred) { + super(child, inferred); + this.isGeneratedIsNotNull = isGeneratedIsNotNull; + } + public Not(Expression child, boolean isGeneratedIsNotNull) { super(ImmutableList.of(child)); this.isGeneratedIsNotNull = isGeneratedIsNotNull; @@ -115,4 +120,9 @@ public Not withGeneratedIsNotNull(boolean isGeneratedIsNotNull) { public List expectedInputTypes() { return EXPECTS_INPUT_TYPES; } + + @Override + public Expression withInferred(boolean inferred) { + return new Not(this.children, false, inferred); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StringRegexPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StringRegexPredicate.java index 4d31f200cd9577..8900ac928590c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StringRegexPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/StringRegexPredicate.java @@ -42,7 +42,11 @@ public abstract class StringRegexPredicate extends ScalarFunction ); protected StringRegexPredicate(String name, List children) { - super(name, children); + this(name, children, false); + } + + protected StringRegexPredicate(String name, List children, boolean inferred) { + super(name, children, inferred); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index c0f4ddc44044ac..5ccc64a34bb43b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -50,7 +50,11 @@ public BoundFunction(String name, Expression... arguments) { } public BoundFunction(String name, List children) { - super(name, children); + this(name, children, false); + } + + public BoundFunction(String name, List children, boolean inferred) { + super(name, children, inferred); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java index 9e4c19365d837f..d8cb79b6ef422a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Function.java @@ -35,7 +35,11 @@ public Function(String name, Expression... children) { } public Function(String name, List children) { - super(children); + this(name, children, false); + } + + public Function(String name, List children, boolean inferred) { + super(children, inferred); this.name = Objects.requireNonNull(name, "name can not be null"); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java index 7267ecc8997be0..97c0e851db66d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ScalarFunction.java @@ -33,7 +33,11 @@ public ScalarFunction(String name, Expression... arguments) { } public ScalarFunction(String name, List arguments) { - super(name, arguments); + this(name, arguments, false); + } + + public ScalarFunction(String name, List arguments, boolean inferred) { + super(name, arguments, inferred); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PredicateInferUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PredicateInferUtils.java new file mode 100644 index 00000000000000..ab840848a812d8 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PredicateInferUtils.java @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.rewrite.InferPredicateByReplace; +import org.apache.doris.nereids.rules.rewrite.UnequalPredicateInfer; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.DateV2Type; +import org.apache.doris.nereids.types.coercion.CharacterType; +import org.apache.doris.nereids.types.coercion.DateLikeType; +import org.apache.doris.nereids.types.coercion.IntegralType; + +import java.util.LinkedHashSet; +import java.util.Optional; +import java.util.Set; + +/** PredicateInferUtils */ +public class PredicateInferUtils { + private enum InferType { + NONE(null), + INTEGRAL(IntegralType.class), + STRING(CharacterType.class), + DATE(DateLikeType.class), + OTHER(DataType.class); + + private final Class superClazz; + + InferType(Class superClazz) { + this.superClazz = superClazz; + } + } + + public static boolean isSlotOrLiteral(Expression expr) { + return expr instanceof SlotReference || expr instanceof Literal; + } + + /**The inputs predicate is divided into two parts. One is the predicate directly reserved, which does not enter + * the non equivalent derivation, and the other is the predicates entering the non equivalent derivation*/ + public static void getComplexAndSimplePredicates(Set inputs, Set complex, + Set simple) { + for (Expression input : inputs) { + if (input instanceof GreaterThan || input instanceof GreaterThanEqual + || input instanceof EqualTo || input instanceof LessThan + || input instanceof LessThanEqual) { + simple.add((ComparisonPredicate) input); + } else { + complex.add(input); + } + } + } + + /**The predicate derivation is based on the input predicate predicates, which is divided into two parts. + * The equivalent relation used in ReplacePredicate and calculated by union-find derive like, in, not + * and ComparisonPredicate; + * The NonEqualPredicateInfer class deduces predicates based on non-equal relations, and deletes + * the useless ComparisonPredicates derived from ReplacePredicate*/ + public static Set inferPredicate(Set predicates) { + if (predicates.size() < 2) { + return predicates; + } + Set inferAndOriginPredicates = InferPredicateByReplace.infer(predicates); + Set inferPredicates = new LinkedHashSet<>( + UnequalPredicateInfer.inferUnequalPredicates(inferAndOriginPredicates)); + // Keep the order of predicates. The input predicates are in the front + // and the derived predicates are in the rear + Set res = new LinkedHashSet<>(); + for (Expression pred : predicates) { + if (inferPredicates.contains(pred)) { + res.add(pred); + inferPredicates.remove(pred); + } + } + res.addAll(inferPredicates); + return res; + } + + /** get all predicates(with redundant predicates), e.g. b>1 a>b -> a>1 a>b b>1*/ + public static Set inferAllPredicate(Set predicates) { + if (predicates.size() < 2) { + return predicates; + } + Set inferAndOriginPredicates = InferPredicateByReplace.infer(predicates); + return new LinkedHashSet<>(UnequalPredicateInfer.inferAllPredicates(inferAndOriginPredicates)); + } + + /**getPairFromCast*/ + public static Optional> getPairFromCast(ComparisonPredicate comparisonPredicate) { + DataType leftType = comparisonPredicate.left().getDataType(); + InferType inferType; + if (leftType instanceof CharacterType) { + inferType = InferType.STRING; + } else if (leftType instanceof IntegralType) { + inferType = InferType.INTEGRAL; + } else if (leftType instanceof DateLikeType) { + inferType = InferType.DATE; + } else { + inferType = InferType.OTHER; + } + Optional left = validForInfer(comparisonPredicate.left(), inferType); + Optional right = validForInfer(comparisonPredicate.right(), inferType); + if (!left.isPresent() || !right.isPresent()) { + return Optional.empty(); + } + return Optional.of(Pair.of(left.get(), right.get())); + } + + private static Optional validForInfer(Expression expression, InferType inferType) { + if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { + return Optional.empty(); + } + if (!(expression instanceof Cast)) { + return Optional.of(expression); + } + Cast cast = (Cast) expression; + Expression child = cast.child(); + DataType dataType = cast.getDataType(); + DataType childType = child.getDataType(); + if (inferType == InferType.INTEGRAL) { + if (dataType instanceof IntegralType) { + IntegralType integralType = (IntegralType) dataType; + if (childType instanceof IntegralType && integralType.widerThan((IntegralType) childType)) { + return validForInfer(((Cast) expression).child(), inferType); + } + } + } else if (inferType == InferType.DATE) { + // avoid lost precision + if (dataType instanceof DateType) { + if (childType instanceof DateV2Type || childType instanceof DateType) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateV2Type) { + if (childType instanceof DateType || childType instanceof DateV2Type) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeType) { + if (!(childType instanceof DateTimeV2Type)) { + return validForInfer(child, inferType); + } + } else if (dataType instanceof DateTimeV2Type) { + return validForInfer(child, inferType); + } + } else if (inferType == InferType.STRING) { + // avoid substring cast such as cast(char(3) as char(2)) + if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { + return validForInfer(child, inferType); + } + } + return Optional.empty(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniformTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniformTest.java index ce9fe85942e67d..8460425a32a623 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniformTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniformTest.java @@ -209,6 +209,10 @@ void testWindow() { @Test void testEqual() { + // Because in INFER_PREDICATES, id=1 and id=id2 is rewritten as id=1 and id2=1 + // The equivalence set in DataTrait does not support the id=1 id2=1->id=id2 temporarily, + // so in order to run through this case, Disable INFER_PREDICATES temporarily + connectContext.getSessionVariable().setDisableNereidsRules("INFER_PREDICATES,PRUNE_EMPTY_PARTITION"); Plan plan = PlanChecker.from(connectContext) .analyze("select id2 from agg where id = 1 and id = id2") .rewrite() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java new file mode 100644 index 00000000000000..98fbbfbec13f2e --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferPredicateByReplaceTest.java @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.Like; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.functions.scalar.DateTrunc; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.Set; + +public class InferPredicateByReplaceTest { + @Test + public void testInferWithEqualTo() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + EqualTo equalTo = new EqualTo(a, b); + Set inputs = new HashSet<>(); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(1, result.size(), "Expected no additional predicates."); + } + + @Test + public void testInferWithInPredicate() { + // abs(a) IN (1, 2, 3) + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + InPredicate inPredicate = new InPredicate(new Abs(a), + ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2), new IntegerLiteral(3))); + EqualTo equalTo = new EqualTo(a, b); + Set inputs = new HashSet<>(); + inputs.add(inPredicate); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(3, result.size()); + } + + @Test + public void testInferWithInPredicateNotSupport() { + // a IN (1, b) + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + InPredicate inPredicate = new InPredicate(a, + ImmutableList.of(new IntegerLiteral(1), b)); + EqualTo equalTo = new EqualTo(a, b); + Set inputs = new HashSet<>(); + inputs.add(inPredicate); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(2, result.size()); + } + + @Test + public void testInferWithNotPredicate() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + InPredicate inPredicate = new InPredicate(a, ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2))); + Not notPredicate = new Not(inPredicate); + EqualTo equalTo = new EqualTo(a, b); + Set inputs = new HashSet<>(); + inputs.add(notPredicate); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Not expected = new Not(new InPredicate(b, ImmutableList.of(new IntegerLiteral(1), new IntegerLiteral(2)))); + Assertions.assertTrue(result.contains(expected)); + } + + @Test + public void testInferWithLikePredicate() { + // a LIKE 'test%' + SlotReference a = new SlotReference("a", StringType.INSTANCE); + SlotReference b = new SlotReference("b", StringType.INSTANCE); + EqualTo equalTo = new EqualTo(a, b); + Like like = new Like(a, new StringLiteral("test%")); + Set inputs = new HashSet<>(); + inputs.add(like); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Like expected = new Like(b, new StringLiteral("test%")); + Assertions.assertEquals(3, result.size()); + Assertions.assertTrue(result.contains(expected), "Expected to find b like 'test%' in the result"); + } + + @Test + public void testInferWithLikePredicateNotSupport() { + // a LIKE b + SlotReference a = new SlotReference("a", StringType.INSTANCE); + SlotReference b = new SlotReference("b", StringType.INSTANCE); + EqualTo equalTo = new EqualTo(a, b); + Like like = new Like(a, b); + Set inputs = new HashSet<>(); + inputs.add(like); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(2, result.size()); + } + + @Test + public void testInferWithOrPredicate() { + SlotReference a = new SlotReference("a", DateTimeV2Type.SYSTEM_DEFAULT); + SlotReference b = new SlotReference("b", DateTimeV2Type.SYSTEM_DEFAULT); + EqualTo equalTo = new EqualTo(a, b); + Or or = new Or(new GreaterThan(a, new DateTimeV2Literal("2022-02-01 10:00:00")), + new LessThan(a, new DateTimeV2Literal("2022-01-01 10:00:00"))); + Set inputs = new HashSet<>(); + inputs.add(or); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(2, result.size()); + } + + @Test + public void testInferWithPredicateDateTrunc() { + SlotReference a = new SlotReference("a", DateTimeV2Type.SYSTEM_DEFAULT); + SlotReference b = new SlotReference("b", DateTimeV2Type.SYSTEM_DEFAULT); + EqualTo equalTo = new EqualTo(a, b); + GreaterThan greaterThan = new GreaterThan(new DateTrunc(a, new VarcharLiteral("year")), new DateTimeV2Literal("2022-02-01 10:00:00")); + Set inputs = new HashSet<>(); + inputs.add(greaterThan); + inputs.add(equalTo); + + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(3, result.size()); + } + + @Test + public void testValidForInfer() { + SlotReference a = new SlotReference("a", TinyIntType.INSTANCE); + Cast castExprA = new Cast(a, IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + Cast castExprB = new Cast(b, IntegerType.INSTANCE); + SlotReference c = new SlotReference("c", DateType.INSTANCE); + Cast castExprC = new Cast(c, IntegerType.INSTANCE); + + EqualTo equalTo1 = new EqualTo(castExprA, castExprB); + EqualTo equalTo2 = new EqualTo(castExprA, castExprC); + Set inputs = new HashSet<>(); + inputs.add(equalTo1); + inputs.add(equalTo2); + Assertions.assertEquals(2, InferPredicateByReplace.infer(inputs).size()); + } + + @Test + public void testNotInferWithTransitiveEqualitySameTable() { + // a = b, b = c + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + EqualTo equalTo1 = new EqualTo(a, b); + EqualTo equalTo2 = new EqualTo(b, c); + Set inputs = new HashSet<>(); + inputs.add(equalTo1); + inputs.add(equalTo2); + Set result = InferPredicateByReplace.infer(inputs); + Assertions.assertEquals(2, result.size()); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java deleted file mode 100644 index 1efa94451af6dd..00000000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java +++ /dev/null @@ -1,67 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.GreaterThan; -import org.apache.doris.nereids.trees.expressions.InPredicate; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.SmallIntType; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import org.junit.jupiter.api.Test; - -import java.util.Set; - -class PredicatePropagationTest { - private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); - private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); - private final SlotReference c = new SlotReference("c", BigIntType.INSTANCE); - - @Test - void equal() { - Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, Literal.of(1))); - Set inferExprs = PredicatePropagation.infer(exprs); - System.out.println(inferExprs); - } - - @Test - void in() { - Set exprs = ImmutableSet.of(new EqualTo(a, b), new InPredicate(a, ImmutableList.of(Literal.of(1)))); - Set inferExprs = PredicatePropagation.infer(exprs); - System.out.println(inferExprs); - } - - @Test - void inferSlotEqual() { - Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c)); - Set inferExprs = PredicatePropagation.infer(exprs); - System.out.println(inferExprs); - } - - @Test - void inferComplex0() { - Set exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c), new GreaterThan(a, Literal.of(1))); - Set inferExprs = PredicatePropagation.infer(exprs); - System.out.println(inferExprs); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInferTest.java new file mode 100644 index 00000000000000..7bd43c98929bc2 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/UnequalPredicateInferTest.java @@ -0,0 +1,688 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.rules.rewrite.UnequalPredicateInfer.InferenceGraph; +import org.apache.doris.nereids.rules.rewrite.UnequalPredicateInfer.InferenceGraph.Relation; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.DateV2Type; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.PredicateInferUtils; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +public class UnequalPredicateInferTest { + @Test + public void testInferWithTransitiveEqualitySameTable() { + // t1.a = t1.b, t1.b = t1.c only output 2 predicates + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + EqualTo equalTo1 = new EqualTo(a, b); + EqualTo equalTo2 = new EqualTo(b, c); + Set inputs = new LinkedHashSet<>(); + inputs.add(equalTo1); + inputs.add(equalTo2); + Set result = UnequalPredicateInfer.inferUnequalPredicates(inputs); + EqualTo expected1 = new EqualTo(a, b); + EqualTo expected2 = new EqualTo(a, c); + Assertions.assertEquals(2, result.size()); + Assertions.assertTrue(result.contains(expected1) && result.contains(expected2)); + } + + @Test + public void testTopoSort() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + // a b c has index 0 1 2 (sort by toSql()) + // a>b b>c + ComparisonPredicate gt1 = new GreaterThan(a, b); + ComparisonPredicate gt2 = new GreaterThan(b, c); + Set inputs = new LinkedHashSet<>(); + inputs.add(gt1); + inputs.add(gt2); + UnequalPredicateInfer.InferenceGraph inferenceGraph = new UnequalPredicateInfer.InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + List res = inferenceGraph.topoSort(); + // list(2,1,0) means order c b a + List expected = Arrays.asList(2, 1, 0); + Assertions.assertEquals(expected, res); + // a>=b b>=c + ComparisonPredicate gte1 = new GreaterThanEqual(a, b); + ComparisonPredicate gte2 = new GreaterThanEqual(b, c); + Set inputs2 = new LinkedHashSet<>(); + inputs2.add(gte1); + inputs2.add(gte2); + UnequalPredicateInfer.InferenceGraph inferenceGraph2 = new UnequalPredicateInfer.InferenceGraph(inputs2); + inferenceGraph2.deduce(inferenceGraph2.getGraph()); + List res2 = inferenceGraph2.topoSort(); + List expected2 = Arrays.asList(2, 1, 0); + Assertions.assertEquals(expected2, res2); + // a<=b b<=c + ComparisonPredicate lte1 = new LessThanEqual(a, b); + ComparisonPredicate lte2 = new LessThanEqual(b, c); + Set inputs3 = new LinkedHashSet<>(); + inputs3.add(lte1); + inputs3.add(lte2); + UnequalPredicateInfer.InferenceGraph inferenceGraph3 = new UnequalPredicateInfer.InferenceGraph(inputs3); + inferenceGraph3.deduce(inferenceGraph3.getGraph()); + List res3 = inferenceGraph3.topoSort(); + List expected3 = Arrays.asList(0, 1, 2); + Assertions.assertEquals(expected3, res3); + // a<=b b inputs4 = new LinkedHashSet<>(); + inputs4.add(lte3); + inputs4.add(gt3); + UnequalPredicateInfer.InferenceGraph inferenceGraph4 = new UnequalPredicateInfer.InferenceGraph(inputs4); + inferenceGraph4.deduce(inferenceGraph4.getGraph()); + List res4 = inferenceGraph4.topoSort(); + List expected4 = Arrays.asList(0, 1, 2); + Assertions.assertEquals(expected4, res4); + } + + @Test + public void testTopoSortWithEqual() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + // a=b b>c + ComparisonPredicate gt1 = new EqualTo(a, b); + ComparisonPredicate gt2 = new GreaterThan(b, c); + Set inputs = new LinkedHashSet<>(); + inputs.add(gt1); + inputs.add(gt2); + UnequalPredicateInfer.InferenceGraph inferenceGraph = new UnequalPredicateInfer.InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + List res = inferenceGraph.topoSort(); + // order is c a b + List expected = Arrays.asList(2, 0, 1); + Assertions.assertEquals(expected, res); + } + + @Test + public void testTopoSortWithEqualMulti() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal d = new IntegerLiteral(1); + // a=b b>c 1 inputs = new LinkedHashSet<>(); + inputs.add(eq); + inputs.add(gt); + inputs.add(lte); + UnequalPredicateInfer.InferenceGraph inferenceGraph = new UnequalPredicateInfer.InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + List res = inferenceGraph.topoSort(); + // order is 1 c a b + List expected = Arrays.asList(0, 3, 1, 2); + Assertions.assertEquals(expected, res); + } + + public void initGraph(Relation[][] g, int size) { + for (int i = 0; i < size; ++i) { + for (int j = 0; j < size; ++j) { + g[i][j] = Relation.UNDEFINED; + } + } + } + + public static void assert2DArrayEquals(Relation[][] expected, Relation[][] actual) { + for (int i = 0; i < expected.length; i++) { + Assertions.assertArrayEquals(expected[i], actual[i], "Row " + i + " is not equal"); + } + } + + // t1.a = 1, t1.b = 1 -> t1.a = 1, t1.b = 1 + @Test + public void testChooseEqualPredicatesSameTable1() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, d); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2)); + } + + // t1.a = 1, t1.b = 1, t1.c = 1 -> t1.a = 1, t1.b = 1, t1.c = 1 + @Test + public void testChooseEqualPredicatesSameTable2() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, d); + ComparisonPredicate eq3 = new EqualTo(c, d); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[4][4]; + initGraph(expected, 4); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[0][3] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[3][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2) + && equalWithLiteral.contains(3)); + } + + // t1.a = 1, t1.b = t1.a -> t1.a = 1, t1.b = 1 + @Test + public void testChooseEqualPredicatesSameTable3() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2)); + } + + // t1.a = 1, t1.b = t1.a, t1.a = t1.c -> t1.a = 1, t1.b = 1, t1.c = 1 + @Test + public void testChooseEqualPredicatesSameTable4() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, a); + ComparisonPredicate eq3 = new EqualTo(c, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[4][4]; + initGraph(expected, 4); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[0][3] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[3][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2) + && equalWithLiteral.contains(3)); + } + + // t1.a = 1, t1.b = t1.a, t1.d = t1.c -> t1.a = 1, t1.b = 1, t1.c = t1.d + @Test + public void testChooseEqualPredicatesSameTable5() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference d = new SlotReference("d", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, literal); + ComparisonPredicate eq2 = new EqualTo(b, a); + ComparisonPredicate eq3 = new EqualTo(d, c); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[5][5]; + initGraph(expected, 5); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[3][4] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[4][3] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2)); + } + + @Test + // t1.a = 1, t2.b = 1 -> t1.a = 1, t2.b = 1 + public void testChooseEqualPredicatesDiffTable1() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, d); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2)); + } + + // t1.a = 1, t2.b = 1, t3.c = 1 -> t1.a = 1, t2.b = 1, t2.c = 1 + @Test + public void testChooseEqualPredicatesDiffTable2() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, d); + ComparisonPredicate eq3 = new EqualTo(c, d); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[4][4]; + initGraph(expected, 4); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[0][3] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[3][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2) + && equalWithLiteral.contains(3)); + } + + // t1.a = 1, t2.b = t1.a, t1.a = t3.c -> t1.a = 1, t2.b = 1, t3.c = 1 + @Test + public void testChooseEqualPredicatesDiffTable3() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + Literal d = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, d); + ComparisonPredicate eq2 = new EqualTo(b, a); + ComparisonPredicate eq3 = new EqualTo(c, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[4][4]; + initGraph(expected, 4); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[0][3] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[3][0] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2) + && equalWithLiteral.contains(3)); + } + + // t1.a = 1, t2.b = t1.a, t4.d = t3.c -> t1.a = 1, t2.b = 1, t4.d = t3.c + @Test + public void testChooseEqualPredicatesDiffTable5() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + SlotReference d = new SlotReference("d", IntegerType.INSTANCE, true, ImmutableList.of("t4")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate eq1 = new EqualTo(a, literal); + ComparisonPredicate eq2 = new EqualTo(b, a); + ComparisonPredicate eq3 = new EqualTo(d, c); + Set inputs = new LinkedHashSet<>(); + inputs.add(eq1); + inputs.add(eq2); + inputs.add(eq3); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithLiteral = new HashSet<>(); + Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithLiteral); + Relation[][] expected = new Relation[5][5]; + initGraph(expected, 5); + expected[0][1] = Relation.EQ; + expected[0][2] = Relation.EQ; + expected[1][0] = Relation.EQ; + expected[2][0] = Relation.EQ; + expected[3][4] = Relation.EQ; + expected[4][3] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + Assertions.assertTrue(equalWithLiteral.contains(1) && equalWithLiteral.contains(2)); + Set chosenInputs = inferenceGraph.chooseInputPredicates(chosen); + // expected[3][4] (t1.d=t1.c) choose in chooseInputPredicates + Assertions.assertTrue(chosenInputs.contains(eq3)); + } + + // a>1 b>a -> a>1 b>a + @Test + public void testChooseUnequalPredicatesSameTable1() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate cmp1 = new GreaterThan(a, literal); + ComparisonPredicate cmp2 = new GreaterThan(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(2, sets.size()); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithConstant = new HashSet<>(); + InferenceGraph.Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithConstant); + inferenceGraph.chooseUnequalPredicates(chosen, equalWithConstant); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[1][0] = Relation.GT; + expected[2][1] = Relation.GT; + assert2DArrayEquals(expected, chosen); + } + + // a<1 b=a -> b<1 b=a + @Test + public void testChooseUnequalPredicatesSameTable2() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate cmp1 = new LessThan(a, literal); + ComparisonPredicate cmp2 = new EqualTo(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(2, sets.size()); + Assertions.assertTrue(sets.contains(new LessThan(b, literal)) && sets.contains(cmp2)); + for (Expression e : sets) { + if (e.equals(cmp2)) { + Assertions.assertFalse(e.isInferred()); + } else { + Assertions.assertTrue(e.isInferred()); + } + } + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithConstant = new HashSet<>(); + InferenceGraph.Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithConstant); + inferenceGraph.chooseUnequalPredicates(chosen, equalWithConstant); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[1][2] = Relation.EQ; + expected[2][1] = Relation.EQ; + expected[0][2] = Relation.GT; + assert2DArrayEquals(expected, chosen); + } + + // t1.a>1 t1.b>t1.a -> t1.a>1,t1.b>1,t1.b>t1.a + @Test + public void testChooseUnequalPredicatesDiffTable1() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate cmp1 = new GreaterThan(a, literal); + ComparisonPredicate cmp2 = new GreaterThan(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(3, sets.size()); + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithConstant = new HashSet<>(); + InferenceGraph.Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithConstant); + inferenceGraph.chooseUnequalPredicates(chosen, equalWithConstant); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + // t1.a>1,t1.b>1 is chosen in chooseUnequalPredicates + expected[1][0] = Relation.GT; + expected[2][0] = Relation.GT; + assert2DArrayEquals(expected, chosen); + } + + // t1.a<1 t2.b=t1.a -> t2.b<1 t2.a<1 t2.b=t1.a + @Test + public void testChooseUnequalPredicatesDiffTable2() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate cmp1 = new LessThan(b, literal); + ComparisonPredicate cmp2 = new EqualTo(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(3, sets.size()); + Assertions.assertTrue(sets.contains(new LessThan(b, literal)) && sets.contains(cmp2) && sets.contains(cmp1)); + for (Expression e : sets) { + if (e.equals(cmp1) || e.equals(cmp2)) { + Assertions.assertFalse(e.isInferred()); + } else { + Assertions.assertTrue(e.isInferred()); + } + } + InferenceGraph inferenceGraph = new InferenceGraph(inputs); + inferenceGraph.deduce(inferenceGraph.getGraph()); + Set equalWithConstant = new HashSet<>(); + InferenceGraph.Relation[][] chosen = inferenceGraph.chooseEqualPredicates(equalWithConstant); + inferenceGraph.chooseUnequalPredicates(chosen, equalWithConstant); + Relation[][] expected = new Relation[3][3]; + initGraph(expected, 3); + expected[0][2] = Relation.GT; + expected[0][1] = Relation.GT; + expected[1][2] = Relation.EQ; + expected[2][1] = Relation.EQ; + assert2DArrayEquals(expected, chosen); + } + + // t1.a=t2.b t1.a=t3.c t2.b=t3.c -> t1.a=t2.b t1.a=t3.c t2.b=t3.c + @Test + public void testInferWithTransitiveEqualityDifferentTableThreeConjuncts1() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + ComparisonPredicate cmp1 = new EqualTo(a, b); + ComparisonPredicate cmp2 = new EqualTo(a, c); + ComparisonPredicate cmp3 = new EqualTo(b, c); + + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + inputs.add(cmp3); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(3, sets.size()); + Assertions.assertTrue(sets.contains(cmp1) && sets.contains(cmp2) && sets.contains(cmp3)); + } + + // t1.a=t3.c t1.a=t2.b t1.b=t3.c -> t1.a=t2.b t1.a=t3.c t2.b=t3.c + @Test + public void testInferWithTransitiveEqualityDifferentTableTwoConjuncts() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + ComparisonPredicate cmp1 = new EqualTo(a, c); + ComparisonPredicate cmp2 = new EqualTo(a, b); + ComparisonPredicate cmp3 = new EqualTo(b, c); + + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = UnequalPredicateInfer.inferUnequalPredicates(inputs); + Assertions.assertEquals(3, sets.size()); + Assertions.assertTrue(sets.contains(cmp1) && sets.contains(cmp2) && sets.contains(cmp3)); + } + + // t1.a=t3.c t1.a=t2.b t1.b=t3.c -> t1.a=t2.b t1.a=t3.c t2.b=t3.c + @Test + public void testUtilChooseMultiEquals() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + ComparisonPredicate cmp1 = new EqualTo(a, c); + ComparisonPredicate cmp2 = new EqualTo(a, b); + ComparisonPredicate cmp3 = new EqualTo(b, c); + + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + inputs.add(cmp3); + Set sets = PredicateInferUtils.inferPredicate(inputs); + Assertions.assertEquals(3, sets.size()); + Assertions.assertTrue(sets.contains(cmp1) && sets.contains(cmp2) && sets.contains(cmp3)); + } + + // t1.a=t3.c t1.a=t2.b -> t1.a=t2.b t1.a=t3.c t2.b=t3.c + @Test + public void testUtilChooseMultiEquals2() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE, true, ImmutableList.of("t3")); + ComparisonPredicate cmp1 = new EqualTo(a, c); + ComparisonPredicate cmp2 = new EqualTo(a, b); + ComparisonPredicate cmp3 = new EqualTo(b, c); + + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = PredicateInferUtils.inferPredicate(inputs); + Assertions.assertEquals(3, sets.size()); + Assertions.assertTrue(sets.contains(cmp1) && sets.contains(cmp2) && sets.contains(cmp3)); + } + + @Test + public void testPredicateUtils() { + SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE, true, ImmutableList.of("t1")); + Literal literal = new IntegerLiteral(1); + ComparisonPredicate cmp1 = new LessThan(a, literal); + ComparisonPredicate cmp2 = new EqualTo(b, a); + Set inputs = new LinkedHashSet<>(); + inputs.add(cmp1); + inputs.add(cmp2); + Set sets = PredicateInferUtils.inferPredicate(inputs); + Assertions.assertEquals(2, sets.size()); + Assertions.assertTrue(sets.contains(new LessThan(b, literal)) && sets.contains(cmp2)); + for (Expression e : sets) { + if (e.equals(cmp2)) { + Assertions.assertFalse(e.isInferred()); + } else { + Assertions.assertTrue(e.isInferred()); + } + } + } + + @Test + public void testInferWithTransitiveEqualityWithCastDateToDateTime() { + // cast(d_datev2 as datetime) = cast(d_datev2 as datetime) + SlotReference a = new SlotReference("a", DateV2Type.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", DateV2Type.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", DateTimeType.INSTANCE, true, ImmutableList.of("t3")); + EqualTo equalTo1 = new EqualTo(new Cast(a, DateTimeType.INSTANCE), c); + EqualTo equalTo2 = new EqualTo(new Cast(b, DateTimeType.INSTANCE), c); + Set inputs = new HashSet<>(); + inputs.add(equalTo1); + inputs.add(equalTo2); + Set result = UnequalPredicateInfer.inferUnequalPredicates(inputs); + EqualTo expected = new EqualTo(a, b); + Assertions.assertTrue(result.contains(expected) || result.contains(expected.commute()), "Expected to find a = b in the result."); + } + + @Test + public void testInferWithTransitiveEqualityWithCastDatev2andDate() { + // cast(d_datev2 as date) = cast(d_date as d_datev2) + SlotReference a = new SlotReference("a", DateV2Type.INSTANCE, true, ImmutableList.of("t1")); + SlotReference b = new SlotReference("b", DateV2Type.INSTANCE, true, ImmutableList.of("t2")); + SlotReference c = new SlotReference("c", DateType.INSTANCE, true, ImmutableList.of("t3")); + EqualTo equalTo1 = new EqualTo(new Cast(a, DateType.INSTANCE), c); + EqualTo equalTo2 = new EqualTo(b, new Cast(c, DateV2Type.INSTANCE)); + + Set inputs = new HashSet<>(); + inputs.add(equalTo1); + inputs.add(equalTo2); + Set result = UnequalPredicateInfer.inferUnequalPredicates(inputs); + EqualTo expected = new EqualTo(a, b); + Assertions.assertTrue(result.contains(expected) || result.contains(expected.commute()), "Expected to find a = b in the result."); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/policy/PolicyTest.java b/fe/fe-core/src/test/java/org/apache/doris/policy/PolicyTest.java index f803dc10563193..7d48be4da9ee99 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/policy/PolicyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/policy/PolicyTest.java @@ -222,13 +222,17 @@ public void testComplexSql() throws Exception { createPolicy("CREATE ROW POLICY test_row_policy1 ON test.table1 AS RESTRICTIVE TO test_policy USING (k1 = 1)"); createPolicy("CREATE ROW POLICY test_row_policy2 ON test.table1 AS RESTRICTIVE TO test_policy USING (k2 = 1)"); String joinSql = "select * from table1 join table2 on table1.k1=table2.k1"; - Assertions.assertTrue(getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); + Assertions.assertTrue(getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k2 = 1) AND (k1 = 1))") + || getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); String unionSql = "select * from table1 union select * from table2"; - Assertions.assertTrue(getSQLPlanOrErrorMsg(unionSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); + Assertions.assertTrue(getSQLPlanOrErrorMsg(unionSql).contains("PREDICATES: ((k2 = 1) AND (k1 = 1))") + || getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); String subQuerySql = "select * from table2 where k1 in (select k1 from table1)"; - Assertions.assertTrue(getSQLPlanOrErrorMsg(subQuerySql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); + Assertions.assertTrue(getSQLPlanOrErrorMsg(subQuerySql).contains("PREDICATES: ((k2 = 1) AND (k1 = 1))") + || getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); String aliasSql = "select * from table1 t1 join table2 t2 on t1.k1=t2.k1"; - Assertions.assertTrue(getSQLPlanOrErrorMsg(aliasSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); + Assertions.assertTrue(getSQLPlanOrErrorMsg(aliasSql).contains("PREDICATES: ((k2 = 1) AND (k1 = 1))") + || getSQLPlanOrErrorMsg(joinSql).contains("PREDICATES: ((k1 = 1) AND (k2 = 1))")); dropPolicy("DROP ROW POLICY test_row_policy1 ON test.table1"); dropPolicy("DROP ROW POLICY test_row_policy2 ON test.table1"); } diff --git a/regression-test/data/nereids_hint_tpch_p0/shape/q12.out b/regression-test/data/nereids_hint_tpch_p0/shape/q12.out index ad76dd8bd9f453..a8710941069079 100644 --- a/regression-test/data/nereids_hint_tpch_p0/shape/q12.out +++ b/regression-test/data/nereids_hint_tpch_p0/shape/q12.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] Hint log: diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out new file mode 100644 index 00000000000000..6976dd752a6de9 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.out @@ -0,0 +1,686 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test_integer_cast -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = expr_cast(d_tinyint as INT))) otherCondition=() +----filter((t1.d_tinyint < 10)) +------PhysicalOlapScan[extend_infer_t1] +----PhysicalOlapScan[extend_infer_t1] + +-- !test_simple_compare -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter((t1.d_int < 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_int < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_simple_compare_not_equal -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_simple_compare_datetimev2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((t1.d_datetimev2 = '2024-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_datetimev2 = '2024-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_simple_compare_not_equal_datetimev2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter(( not (d_datetimev2 = '2024-01-01 00:00:00'))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not (d_datetimev2 = '2024-01-01 00:00:00'))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_not_in -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_in -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(d_int IN (10, 20)) +------PhysicalOlapScan[extend_infer_t1] +----filter(d_int IN (10, 20)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_func_not_in -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(( not abs(d_int) IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not abs(d_int) IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_like -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_char100 = t2.d_char100)) otherCondition=() +----filter((d_char100 like '012%')) +------PhysicalOlapScan[extend_infer_t1] +----filter((d_char100 like '012%')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_like_not -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_char100 = t2.d_char100)) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter(( not (d_char100 like '012%'))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_like_to_equal -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_char100 = t2.d_char100)) otherCondition=() +----filter((t1.d_char100 = '012')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_char100 = '012')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_func_not_in_and_func_equal_condition -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_abs(d_int) = expr_abs(d_int))) otherCondition=() +----filter(( not abs(d_int) IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not abs(d_int) IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_between_and -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter((t1.a <= 10) and (t1.a >= 1)) +------PhysicalOlapScan[extend_infer_t3] +----filter((t2.a <= 10) and (t2.a >= 1)) +------PhysicalOlapScan[extend_infer_t4] + +-- !test_and -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter((t1.a <= 10) and (t1.a >= 2)) +------PhysicalOlapScan[extend_infer_t3] +----filter((t2.a <= 10) and (t2.a >= 2)) +------PhysicalOlapScan[extend_infer_t4] + +-- !test_or1 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter(((t1.a < 2) OR (t1.a > 10))) +------PhysicalOlapScan[extend_infer_t3] +----PhysicalOlapScan[extend_infer_t4] + +-- !test_or2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter(((t1.a < 2) OR (t1.a > 10))) +------PhysicalOlapScan[extend_infer_t3] +----PhysicalOlapScan[extend_infer_t4] + +-- !test_sign_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter((sign(cast(a as DOUBLE)) >= 1)) +------PhysicalOlapScan[extend_infer_t3] +----filter((sign(cast(a as DOUBLE)) >= 1)) +------PhysicalOlapScan[extend_infer_t4] + +-- !test_if_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter(if(( not d_int IN (10, 20)), TRUE, FALSE)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_if_and_in_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_if_and_in_predicate_not -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not (if((d_int = 5), TRUE, FALSE) = FALSE))) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_multi_slot_in_predicate1 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_(a + c) = expr_(a + c))) otherCondition=() +----filter(((t1.a + t1.c) < 10)) +------PhysicalOlapScan[extend_infer_t3] +----filter(((t2.a + t2.c) < 10)) +------PhysicalOlapScan[extend_infer_t4] + +-- !test_multi_slot_in_predicate2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a) and (t1.b = t2.b)) otherCondition=() +----filter(((cast(a as DOUBLE) + cast(b as DOUBLE)) < 10.0)) +------PhysicalOlapScan[extend_infer_t3] +----PhysicalOlapScan[extend_infer_t4] + +-- !test_case_when_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter(CASE WHEN (d_int = 1) THEN TRUE WHEN (d_int = 2) THEN FALSE ELSE FALSE END) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_datetimev2_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((convert_tz(date_trunc(d_datetimev2, 'month'), 'Asia/Shanghai', 'Europe/Paris') = '2024-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] +----filter((convert_tz(date_trunc(d_datetimev2, 'month'), 'Asia/Shanghai', 'Europe/Paris') = '2024-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_convert_tz_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') > '2022-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] +----filter((convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') > '2022-01-01 00:00:00')) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_next_date_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_random_nest_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), cast(random(1, 10) as INT))) > 10)) +------PhysicalOlapScan[extend_infer_t1] +----PhysicalOlapScan[extend_infer_t2] + +-- !test_random_predicate -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +----filter((cast(a as DOUBLE) > random(10))) +------PhysicalOlapScan[extend_infer_t3] +----PhysicalOlapScan[extend_infer_t4] + +-- !test_predicate_map -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetimev2 = t2.d_datetimev2)) otherCondition=() +----filter((convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') < '2022-01-01 00:00:00') and (dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') < '2022-01-01 00:00:00') and (dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_int_upcast -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = expr_cast(d_tinyint as INT))) otherCondition=() +----filter((t1.d_int < 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((cast(d_tinyint as INT) < 10) and (t2.d_tinyint < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_int_downcast -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_int as TINYINT) = t2.d_tinyint)) otherCondition=() +----filter((cast(d_int as TINYINT) < 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_tinyint < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_date_upcast -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_datev2 as DATETIMEV2(0)) = t2.d_datetimev2)) otherCondition=() +----filter((t1.d_datev2 < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_datetimev2 < '2022-01-03 00:00:00')) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_date_downcast -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datev2 = expr_cast(d_datetimev2 as DATEV2))) otherCondition=() +----filter((t1.d_datev2 < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t1] +----filter((cast(d_datetimev2 as DATEV2) < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_date_both_upcast1 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_datev2 as DATETIMEV2(0)) = expr_cast(d_date as DATETIMEV2(0)))) otherCondition=() +----filter((t1.d_datev2 < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_date < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_date_both_upcast2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datetime = expr_cast(d_date as DATETIMEV2(0)))) otherCondition=() +----filter((t1.d_datetime < '2022-01-03 00:00:00')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_date < '2022-01-03')) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_char_different_type1 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_char100 = t2.d_char10)) otherCondition=() +----filter((t1.d_char100 > 'abc')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_char10 > 'abc')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_char_different_type2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_substring(cast(d_char100 as CHAR(50)), 1, 50) = t2.d_char10)) otherCondition=() +----filter((substring(cast(d_char100 as CHAR(50)), 1, 50) > 'abc')) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_char10 > 'abc')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_char_different_type3 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_substring(cast(d_char100 as CHAR(50)), 1, 50) = expr_substring(cast(d_char10 as CHAR(50)), 1, 50))) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_char10 > 'abc')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_char_different_type4 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_substring(cast(d_char100 as CHAR(200)), 1, 200) = expr_substring(cast(d_char10 as CHAR(200)), 1, 200))) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_char10 > 'abc')) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_cast_and_func -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_abs(d_int) = expr_cast(d_tinyint as BIGINT))) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_tinyint < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_cast_and_func2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(abs(d_int) as TINYINT) = t2.d_tinyint)) otherCondition=() +----filter((cast(abs(d_int) as TINYINT) < 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((t2.d_tinyint < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_cast_and_func3 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(cast(d_int as TINYINT) as SMALLINT) = expr_abs(d_tinyint))) otherCondition=() +----filter((cast(cast(d_int as TINYINT) as SMALLINT) < 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((abs(d_tinyint) < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_cast_and_func4 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_int = expr_cast(abs(d_tinyint) as INT))) otherCondition=() +----PhysicalOlapScan[extend_infer_t1] +----filter((abs(d_tinyint) < 10)) +------PhysicalOlapScan[extend_infer_t1] + +-- !test_func_equal_and_nest_func_pred1 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') = expr_convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'))) otherCondition=() +----filter((dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((dayofmonth(hours_add(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'), 10)) > 10)) +------PhysicalOlapScan[extend_infer_t2] + +-- !test_func_equal_and_nest_func_pred2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris') = expr_convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris'))) otherCondition=() +----filter((dayofmonth(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris')) > 10)) +------PhysicalOlapScan[extend_infer_t1] +----filter((dayofmonth(convert_tz(d_datetimev2, 'Asia/Shanghai', 'Europe/Paris')) > 10)) +------PhysicalOlapScan[extend_infer_t2] + +-- !predicate_to_empty_relation -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t3.a)) otherCondition=() +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.a = t2.a)) otherCondition=() +------filter((t1.a = 2)) +--------PhysicalOlapScan[extend_infer_t3] +------PhysicalEmptyRelation +----filter((t3.a = 2)) +------PhysicalOlapScan[extend_infer_t4] + +-- !equal_table_predicate_delete -- +PhysicalResultSink +--filter((extend_infer_t3.a = 1) and (extend_infer_t3.c = 1)) +----PhysicalOlapScan[extend_infer_t3] + +-- !test_integer_cast_res -- + +-- !test_simple_compare_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 + +-- !test_simple_compare_not_equal_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_simple_compare_datetimev2_res -- + +-- !test_simple_compare_not_equal_datetimev2_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_not_in_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_in_res -- + +-- !test_func_not_in_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_like_res -- + +-- !test_like_not_res -- + +-- !test_like_to_equal_res -- + +-- !test_func_not_in_and_func_equal_condition_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_between_and_res -- +1 d2 3 5 1 d2 2 2 + +-- !test_and_res -- + +-- !test_or1_res -- +1 d2 3 5 1 d2 2 2 + +-- !test_or2_res -- +1 d2 3 5 1 d2 2 2 + +-- !test_sign_predicate_res -- +1 d2 3 5 1 d2 2 2 + +-- !test_if_predicate_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_if_and_in_predicate_res -- + +-- !test_if_and_in_predicate_not_res -- + +-- !test_multi_slot_in_predicate1_res -- +0 d2 3 5 1 d2 2 2 + +-- !test_multi_slot_in_predicate2_res -- + +-- !test_case_when_predicate_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 + +-- !test_datetimev2_predicate_res -- + +-- !test_convert_tz_predicate_res -- + +-- !test_next_date_predicate_res -- +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_random_nest_predicate_res -- +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_random_predicate_res -- +1 d2 3 5 1 d2 2 2 + +-- !test_predicate_map_res -- +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_int_upcast_res -- + +-- !test_int_downcast_res -- + +-- !test_date_upcast_res -- + +-- !test_date_downcast_res -- +1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 1 01234567890123456789 3 3 0123456789 2020-01-09T10:00:01 2020-01-09 2022-08-09 2022-08-09T10:00 +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_date_both_upcast1_res -- + +-- !test_date_both_upcast2_res -- + +-- !test_char_different_type1_res -- + +-- !test_char_different_type2_res -- + +-- !test_char_different_type3_res -- + +-- !test_char_different_type4_res -- + +-- !test_cast_and_func_res -- + +-- !test_cast_and_func2_res -- + +-- !test_cast_and_func3_res -- + +-- !test_cast_and_func4_res -- + +-- !test_func_equal_and_nest_func_pred1_res -- +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !test_func_equal_and_nest_func_pred2_res -- +14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 14 01234567890123456789 33 23 0123456789 2020-01-11T10:00:01 2020-01-11 2022-08-03 2022-08-09T10:00:02 + +-- !predicate_to_empty_relation_res -- + +-- !equal_table_predicate_delete_res -- + +-- !not_equal_inner_left -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_inner_left2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((t1.d_int = c1)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_left_inner -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_left_left -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_left_left2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((t1.d_int = c1)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_in_inner_right -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_in_inner_right2 -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[RIGHT_OUTER_JOIN] hashCondition=((t1.d_int = c1)) otherCondition=() +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_in_right_inner -- +PhysicalResultSink +--hashJoin[RIGHT_OUTER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_in_right_right -- +PhysicalResultSink +--hashJoin[RIGHT_OUTER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[INNER_JOIN] hashCondition=((c1 = t2.d_int)) otherCondition=() +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_in_right_right2 -- +PhysicalResultSink +--hashJoin[RIGHT_OUTER_JOIN] hashCondition=((t3.d_int = t.c1)) otherCondition=() +----filter(( not d_int IN (10, 20))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[RIGHT_OUTER_JOIN] hashCondition=((t1.d_int = c1)) otherCondition=() +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not d_int IN (10, 20))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_semi_semi_with_cast -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((expr_cast(d_smallint as INT) = t.c1)) otherCondition=() +----filter(( not (cast(d_smallint as INT) = 10)) and ( not (d_smallint = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_SEMI_JOIN] hashCondition=((c1 = expr_cast(d_tinyint as INT))) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (cast(d_tinyint as INT) = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_anti_anti_with_cast -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((expr_cast(d_smallint as INT) = t.c1)) otherCondition=() +----filter(( not (d_smallint = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_ANTI_JOIN] hashCondition=((c1 = expr_cast(d_tinyint as INT))) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (cast(d_tinyint as INT) = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_anti_left_with_cast -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((expr_cast(d_smallint as INT) = t.c1)) otherCondition=() +----filter(( not (d_smallint = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_OUTER_JOIN] hashCondition=((c1 = expr_cast(d_tinyint as INT))) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (cast(d_tinyint as INT) = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !not_equal_semi_anti_with_cast -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((expr_cast(d_smallint as INT) = t.c1)) otherCondition=() +----filter(( not (cast(d_smallint as INT) = 10)) and ( not (d_smallint = 10))) +------PhysicalOlapScan[extend_infer_t1] +----hashJoin[LEFT_ANTI_JOIN] hashCondition=((c1 = expr_cast(d_tinyint as INT))) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (cast(d_tinyint as INT) = 10))) +--------PhysicalOlapScan[extend_infer_t1] + +-- !in_subquery_to_semi_join -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.d_int = extend_infer_t2.d_int)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t2] + +-- !not_in_subquery_to_na_anti_join_not_infer -- +PhysicalResultSink +--hashJoin[NULL_AWARE_LEFT_ANTI_JOIN] hashCondition=((t1.d_int = extend_infer_t2.d_int)) otherCondition=() +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t1] +----PhysicalOlapScan[extend_infer_t2] + +-- !in_subquery_to_semi_join -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.d_int = extend_infer_t2.d_int)) otherCondition=() +----hashJoin[INNER_JOIN] hashCondition=((t1.d_int = t2.d_int)) otherCondition=() +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +------filter(( not (d_int = 10))) +--------PhysicalOlapScan[extend_infer_t1] +----filter(( not (d_int = 10))) +------PhysicalOlapScan[extend_infer_t2] + +-- !cast_to_decimal_overflow_not_infer -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_tinyint as INT) = t2.d_int)) otherCondition=() +----filter(cast(d_tinyint as DECIMALV3(4, 1)) IN (0.1, 0.5)) +------PhysicalOlapScan[extend_infer_t1] +----PhysicalOlapScan[extend_infer_t2] + +-- !char_equal_int_infer -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((expr_cast(d_char10 as DOUBLE) = expr_cast(d_int as DOUBLE))) otherCondition=() +----filter(d_char10 IN ('bb', 'd')) +------PhysicalOlapScan[extend_infer_t1] +----PhysicalOlapScan[extend_infer_t2] + +-- !date_equal_int_infer -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d_datev2 = expr_cast(d_int as DATEV2))) otherCondition=() +----filter(d_datev2 IN ('2024-01-01', '2024-01-02')) +------PhysicalOlapScan[extend_infer_t1] +----filter(cast(d_int as DATEV2) IN ('2024-01-01', '2024-01-02')) +------PhysicalOlapScan[extend_infer_t2] + diff --git a/regression-test/data/nereids_rules_p0/infer_predicate/infer_unequal_predicates.out b/regression-test/data/nereids_rules_p0/infer_predicate/infer_unequal_predicates.out new file mode 100644 index 00000000000000..30e82ec957c3c3 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/infer_predicate/infer_unequal_predicates.out @@ -0,0 +1,165 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !not_infer_same_table_have_mid_column -- +PhysicalResultSink +--filter((t1.a < 5) and (t1.c < t1.a)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !not_infer_same_table_have_mid_literal -- +PhysicalResultSink +--filter((t1.a > 1) and (t1.c < 1)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !not_infer_diff_table_have_mid_literal -- +PhysicalResultSink +--NestedLoopJoin[CROSS_JOIN] +----filter((t1.a < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.a > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_diff_table -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.c < t2.a) +----filter((t1.c < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.a < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !should_infer_because_a_is_key -- +PhysicalResultSink +--filter((t1.a < 5) and (t1.a < t1.c) and (t1.c < 5)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !should_infer_because_d_is_partition_column -- +PhysicalResultSink +--filter((t1.c < 10) and (t1.d < 10) and (t1.d < t1.c)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !infer_with_equal -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.c)) otherCondition=() +----filter((t1.a < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.c < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_4_expr -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.a < t2.a) +----filter((t1.a < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.a < 1) and (t2.a < t2.c) and (t2.c < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_long_chain_same_table_infer_a_and_d -- +PhysicalResultSink +--filter((t1.a < 10) and (t1.a < t1.d) and (t1.c < 10) and (t1.d < 10) and (t1.d < t1.c)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !infer_long_chain_same_table_not_infer_c -- +PhysicalResultSink +--filter((t1.a < 10) and (t1.a < t1.c) and (t1.c < t1.d) and (t1.d < 10)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !remove_useless_input_predicate_c_less_than_10 -- +PhysicalResultSink +--filter((t1.a < 10) and (t1.a < t1.c) and (t1.c < t1.d) and (t1.d < 10)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !remove_useless_predicate -- +PhysicalResultSink +--NestedLoopJoin[CROSS_JOIN] +----filter((t1.a = t1.c) and (t1.a > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_long_chain_diff_table -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.a < t2.d) +----filter((t1.a < 10)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.c < 10) and (t2.d < 10) and (t2.d < t2.c)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_with_constant_and_columns -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.a = t2.c)) otherCondition=() +----filter((t1.a > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.c < t2.d) and (t2.c > 1) and (t2.d > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !no_infer -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.a < t2.d) +----PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.d > t2.c)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !no_infer_cyclic_dependency -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.a < t2.c)(t2.c < t1.a) +----PhysicalOlapScan[infer_unequal_predicates_t1] +----PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_multiple_conditions -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.a < t2.a) +----filter((t1.a < 10)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.a < 10) and (t2.a < t2.c) and (t2.c < t2.d) and (t2.d < 10)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !infer_cast_int -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](t1.d_int > cast(d_smallint as INT)) +----filter((t1.d_int > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t3] +----filter((t2.d_smallint > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t3] + +-- !multi_slot_equal -- +PhysicalResultSink +--filter((infer_unequal_predicates_t1.a = infer_unequal_predicates_t1.c) and (infer_unequal_predicates_t1.a = infer_unequal_predicates_t1.d)) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !no_redundant_predicates -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.d = t2.d)) otherCondition=() +----filter((t1.c > 1) and (t1.d < 10) and (t1.d = t1.c) and (t1.d > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((t2.d < 10) and (t2.d > 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !expr_unequal_infer_same_table1 -- +PhysicalResultSink +--PhysicalEmptyRelation + +-- !expr_unequal_infer_same_table2 -- +PhysicalResultSink +--filter((abs(c) < 1) and (abs(d) < abs(c))) +----PhysicalOlapScan[infer_unequal_predicates_t1] + +-- !expr_unequal_infer_diff_table -- +PhysicalResultSink +--NestedLoopJoin[INNER_JOIN](abs(d) < abs(c)) +----PhysicalOlapScan[infer_unequal_predicates_t1] +----filter((abs(c) < 1)) +------PhysicalOlapScan[infer_unequal_predicates_t2] + +-- !not_infer_expr1 -- +PhysicalResultSink +--PhysicalEmptyRelation + +-- !not_infer_expr2 -- +PhysicalResultSink +--PhysicalEmptyRelation + +-- !not_infer_because_is_infer_and_then_remove -- +PhysicalResultSink +--PhysicalEmptyRelation + +-- !infer_join_equal_condition -- +PhysicalResultSink +--PhysicalEmptyRelation + diff --git a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out index 15144b566b0474..14817af2ee3200 100644 --- a/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out +++ b/regression-test/data/nereids_rules_p0/predicate_infer/infer_predicate.out @@ -354,14 +354,14 @@ PhysicalResultSink PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((t12.id = t34.id)) otherCondition=() ----hashJoin[INNER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------filter((t1.id < 9) and (t1.id > 1)) +------filter(( not (id = 3)) and ( not (id = 4)) and (t1.id < 9) and (t1.id > 1)) --------PhysicalOlapScan[t1] -------filter((t2.id < 9) and (t2.id > 1)) +------filter(( not (id = 3)) and ( not (id = 4)) and (t2.id < 9) and (t2.id > 1)) --------PhysicalOlapScan[t2] ----hashJoin[INNER_JOIN] hashCondition=((t3.id = t4.id)) otherCondition=() -------filter(( not (id = 3)) and (t34.id < 9) and (t34.id > 1)) +------filter(( not (id = 3)) and ( not (id = 4)) and (t34.id < 9) and (t34.id > 1)) --------PhysicalOlapScan[t3] -------filter(( not (id = 4)) and (t4.id < 9) and (t4.id > 1)) +------filter(( not (id = 3)) and ( not (id = 4)) and (t4.id < 9) and (t4.id > 1)) --------PhysicalOlapScan[t4] -- !infer8 -- @@ -384,8 +384,7 @@ PhysicalResultSink --hashJoin[INNER_JOIN] hashCondition=((expr_cast(id as SMALLINT) = expr_cast(id as SMALLINT))) otherCondition=() ----filter((cast(id as BIGINT) = 2147483648)) ------PhysicalOlapScan[t1] -----filter((cast(id as BIGINT) = 2147483648)) -------PhysicalOlapScan[t2] +----PhysicalOlapScan[t2] -- !infer11 -- PhysicalResultSink diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/nostats_rf_prune/q12.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/nostats_rf_prune/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/nostats_rf_prune/q12.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/nostats_rf_prune/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/rf_prune/q12.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/rf_prune/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/rf_prune/q12.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/rf_prune/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q12.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q12.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q12.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q12.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape_no_stats/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/new_shapes_p0/hint_tpch/shape/q12.out b/regression-test/data/new_shapes_p0/hint_tpch/shape/q12.out index ad76dd8bd9f453..a8710941069079 100644 --- a/regression-test/data/new_shapes_p0/hint_tpch/shape/q12.out +++ b/regression-test/data/new_shapes_p0/hint_tpch/shape/q12.out @@ -12,7 +12,7 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] Hint log: diff --git a/regression-test/data/new_shapes_p0/tpch_sf1000/nostats_rf_prune/q12.out b/regression-test/data/new_shapes_p0/tpch_sf1000/nostats_rf_prune/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/new_shapes_p0/tpch_sf1000/nostats_rf_prune/q12.out +++ b/regression-test/data/new_shapes_p0/tpch_sf1000/nostats_rf_prune/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/new_shapes_p0/tpch_sf1000/rf_prune/q12.out b/regression-test/data/new_shapes_p0/tpch_sf1000/rf_prune/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/new_shapes_p0/tpch_sf1000/rf_prune/q12.out +++ b/regression-test/data/new_shapes_p0/tpch_sf1000/rf_prune/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/new_shapes_p0/tpch_sf1000/shape/q12.out b/regression-test/data/new_shapes_p0/tpch_sf1000/shape/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/new_shapes_p0/tpch_sf1000/shape/q12.out +++ b/regression-test/data/new_shapes_p0/tpch_sf1000/shape/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/data/new_shapes_p0/tpch_sf1000/shape_no_stats/q12.out b/regression-test/data/new_shapes_p0/tpch_sf1000/shape_no_stats/q12.out index 95a2108c4ae342..8df830dd428e58 100644 --- a/regression-test/data/new_shapes_p0/tpch_sf1000/shape_no_stats/q12.out +++ b/regression-test/data/new_shapes_p0/tpch_sf1000/shape_no_stats/q12.out @@ -12,6 +12,6 @@ PhysicalResultSink ------------------PhysicalProject --------------------PhysicalOlapScan[orders] apply RFs: RF0 ------------------PhysicalProject ---------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) +--------------------filter((lineitem.l_commitdate < lineitem.l_receiptdate) and (lineitem.l_receiptdate < '1995-01-01') and (lineitem.l_receiptdate >= '1994-01-01') and (lineitem.l_shipdate < '1995-01-01') and (lineitem.l_shipdate < lineitem.l_commitdate) and l_shipmode IN ('MAIL', 'SHIP')) ----------------------PhysicalOlapScan[lineitem] diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index 6e200f70d5a3b1..8e7ecae59f98f5 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -41,10 +41,11 @@ suite("test_infer_predicate") { contains "PREDICATES: (k2" } - explain { - sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" - contains "PREDICATES: (CAST(k2" - } +// not support infer predicate downcast +// explain { +// sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" +// contains "PREDICATES: (CAST(k2" +// } explain { sql "select * from infer_tb1 inner join infer_tb3 where infer_tb3.k1 = infer_tb1.k2 and infer_tb3.k1 = '123';" @@ -55,6 +56,9 @@ suite("test_infer_predicate") { sql "select * from infer_tb1 left join infer_tb2 on infer_tb1.k1 = infer_tb2.k3 left join infer_tb3 on " + "infer_tb2.k3 = infer_tb3.k2 where infer_tb1.k1 = 1;" contains "PREDICATES: (k3" - contains "PREDICATES: (k2" + // After modifying the logic of pull up predicates from join, the left join left table predicate will not be pulled up. + // left join left table predicates should not be pulled up. because there may be null value. + // However, in this case, pulling up seems to be OK, so note for now + // contains "PREDICATES: (k2" } } diff --git a/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy new file mode 100644 index 00000000000000..4b7b4bc504605a --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/infer_predicate/extend_infer_equal_predicate.groovy @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("extend_infer_equal_predicate") { + sql "set enable_fallback_to_original_planner=false" + sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'""" + sql 'set runtime_filter_mode=off' + sql 'set disable_join_reorder=true' + + sql """ + drop table if exists extend_infer_t1; + """ + sql """ + drop table if exists extend_infer_t2; + """ + sql """ + create table extend_infer_t1(d_int int, d_char100 char(100), d_smallint smallint, d_tinyint tinyint, d_char10 char(10),d_datetimev2 datetimev2, d_datev2 datev2,d_date date, d_datetime datetime) properties('replication_num'='1'); + """ + sql """ + create table extend_infer_t2(d_int int, d_char100 char(100), d_smallint smallint, d_tinyint tinyint, d_char10 char(10),d_datetimev2 datetimev2, d_datev2 datev2,d_date date, d_datetime datetime) properties('replication_num'='1'); + """ + sql """ + insert into extend_infer_t1 values(1,'01234567890123456789', 3,3,'0123456789','2020-01-09 10:00:00.99','2020-01-09','2022-08-09','2022-08-09 10:00:00'),(14,'01234567890123456789', 33,23,'0123456789','2020-01-11 10:00:00.99','2020-01-11','2022-08-03','2022-08-09 10:00:02'); + """ + sql """ + insert into extend_infer_t2 values(1,'01234567890123456789', 3,3,'0123456789','2020-01-09 10:00:00.99','2020-01-09','2022-08-09','2022-08-09 10:00:00'),(14,'01234567890123456789', 33,23,'0123456789','2020-01-11 10:00:00.99','2020-01-11','2022-08-03','2022-08-09 10:00:02'); + """ + + sql "drop table if exists extend_infer_t3;" + sql "drop table if exists extend_infer_t4;" + sql "drop table if exists extend_infer_t5;" + + sql """ + CREATE TABLE `extend_infer_t3` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE `extend_infer_t4` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE `extend_infer_t5` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + insert into extend_infer_t3 values(1,'d2',3,5); + """ + sql """ + insert into extend_infer_t4 values(1,'d2',2,2); + """ + sql """ + insert into extend_infer_t5 values(1,'d2',2,2); + """ + sql """ + insert into extend_infer_t4 values(-3,'d2',2,2); + """ + sql """ + insert into extend_infer_t3 values(0,'d2',3,5); + """ + + qt_test_integer_cast """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint where t1.d_tinyint<10;""" + qt_test_simple_compare """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int<10""" + qt_test_simple_compare_not_equal """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int!=10;""" + qt_test_simple_compare_datetimev2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where t2.d_datetimev2='2024-01-01';""" + qt_test_simple_compare_not_equal_datetimev2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where t2.d_datetimev2!='2024-01-01';""" + qt_test_not_in """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int not in (10,20)""" + qt_test_in """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int in (10,20)""" + qt_test_func_not_in """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where abs(t2.d_int) not in (10,20)""" + qt_test_like """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 like '012%'""" + qt_test_like_not """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 not like '012%'""" + qt_test_like_to_equal """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 like '012'""" + qt_test_func_not_in_and_func_equal_condition """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on abs(t1.d_int)=abs(t2.d_int) where abs(t2.d_int) not in (10,20)""" + + qt_test_between_and """explain shape plan + select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and t1.a between 1 and 10;""" + qt_test_and """explain shape plan + select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and (t1.a >=2 and t1.a<=10);""" + qt_test_or1 """explain shape plan + select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and not t1.a between 2 and 10;""" + qt_test_or2 """explain shape plan + select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and not (t1.a >=2 and t1.a<=10);""" + qt_test_sign_predicate """explain shape plan + select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and sign(t1.a)>=1""" + qt_test_if_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where case when t2.d_int not in (10,20) then true else false end""" + qt_test_if_and_in_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where if(t2.d_int =5,true, false) not in (FALSE)""" + qt_test_if_and_in_predicate_not """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where if(t2.d_int =5,true, false) !=FALSE""" + qt_test_multi_slot_in_predicate1 """explain shape plan + select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a+t1.c=t2.a+t2.c and t1.a+t1.c<10""" + qt_test_multi_slot_in_predicate2 """explain shape plan + select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a=t2.a and t1.b=t2.b and t1.a+t1.b<10""" + qt_test_case_when_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where case when t2.d_int=1 then true when t2.d_int=2 then false else false end""" + qt_test_datetimev2_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where convert_tz(date_trunc(t2.d_datetimev2, 'month'),'Asia/Shanghai','Europe/Paris')='2024-01-01';""" + + // function predicate + qt_test_convert_tz_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris')>'2022-01-01';""" + qt_test_next_date_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10;""" + qt_test_random_nest_predicate """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),random(1,10)))>10;""" + qt_test_random_predicate """explain shape plan + select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a=t2.a and t1.a>random(10);""" + qt_test_predicate_map """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10 + and convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') < '2022-01-01';""" + + // test cast + qt_test_int_upcast """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint where t2.d_tinyint<10;""" + qt_test_int_downcast """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_int as tinyint)=t2.d_tinyint where t2.d_tinyint<10;""" + qt_test_date_upcast """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_datev2 =t2.d_datetimev2 and t1.d_datev2<'2022-01-03';""" + qt_test_date_downcast """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_datev2 =cast(t2.d_datetimev2 as datev2) and t1.d_datev2<'2022-01-03';""" + qt_test_date_both_upcast1 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on cast(t1.d_datev2 as datetimev2)=cast(t2.d_date as datetimev2) + and t1.d_datev2<'2022-01-03';""" + qt_test_date_both_upcast2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on cast(t1.d_datetime as datetimev2)=cast(t2.d_date as datetimev2) + and t1.d_datetime<'2022-01-03';""" + // cast char behave differently because of substring + qt_test_char_different_type1 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char10 and t2.d_char10>'abc';""" + qt_test_char_different_type2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(50))=t2.d_char10 and t2.d_char10>'abc';""" + qt_test_char_different_type3 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(50))=cast(t2.d_char10 as char(50)) and t2.d_char10>'abc';""" + qt_test_char_different_type4 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(200))=cast(t2.d_char10 as char(200)) and t2.d_char10>'abc';""" + + qt_test_cast_and_func """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on abs(t1.d_int)=t2.d_tinyint where t2.d_tinyint<10 ;""" + qt_test_cast_and_func2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(abs(t1.d_int) as tinyint)=t2.d_tinyint where t2.d_tinyint<10;""" + // this should be inferred but not + qt_test_cast_and_func3 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_int as tinyint)=abs(t2.d_tinyint) where abs(t2.d_tinyint)<10;""" + qt_test_cast_and_func4 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int =abs(t2.d_tinyint) where abs(t2.d_tinyint)<10;""" + qt_test_func_equal_and_nest_func_pred1 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') =convert_tz(t2.d_datetimev2,'Asia/Shanghai','Europe/Paris') + and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10;""" + qt_test_func_equal_and_nest_func_pred2 """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') =convert_tz(t2.d_datetimev2,'Asia/Shanghai','Europe/Paris') + and day(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'))>10;""" + qt_predicate_to_empty_relation """explain shape plan + select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a and t2.a=1 left join extend_infer_t4 t3 on t1.a=t3.a where t1.a=2""" + qt_equal_table_predicate_delete """ + explain shape plan select * from extend_infer_t3 where a=1 and c=1; + """ + + qt_test_integer_cast_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint where t1.d_tinyint<10 order by t1.d_int;;""" + qt_test_simple_compare_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int<10 order by t1.d_int;""" + qt_test_simple_compare_not_equal_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int!=10 order by t1.d_int;""" + qt_test_simple_compare_datetimev2_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where t2.d_datetimev2='2024-01-01' order by t1.d_int;;""" + qt_test_simple_compare_not_equal_datetimev2_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where t2.d_datetimev2!='2024-01-01' order by t1.d_int;;""" + qt_test_not_in_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int not in (10,20) order by t1.d_int;""" + qt_test_in_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t2.d_int in (10,20) order by t1.d_int ;""" + qt_test_func_not_in_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where abs(t2.d_int) not in (10,20) order by t1.d_int;""" + qt_test_like_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 like '012% order by t1.d_int;'""" + qt_test_like_not_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 not like '012%' order by t1.d_int;""" + qt_test_like_to_equal_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char100 where t2.d_char100 like '012' order by t1.d_int;""" + qt_test_func_not_in_and_func_equal_condition_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on abs(t1.d_int)=abs(t2.d_int) where abs(t2.d_int) not in (10,20) order by t1.d_int;""" + + qt_test_between_and_res """select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and t1.a between 1 and 10 order by 1,2,3,4,5,6,7,8;""" + qt_test_and_res """select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and (t1.a >=2 and t1.a<=10) order by 1,2,3,4,5,6,7,8;""" + qt_test_or1_res """select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and not t1.a between 2 and 10 order by 1,2,3,4,5,6,7,8;""" + qt_test_or2_res """select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and not (t1.a >=2 and t1.a<=10) order by 1,2,3,4,5,6,7,8;""" + qt_test_sign_predicate_res """select * from extend_infer_t3 t1 ,extend_infer_t4 t2 where t1.a=t2.a and sign(t1.a)>=1 order by 1,2,3,4,5,6,7,8""" + qt_test_if_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where case when t2.d_int not in (10,20) then true else false end order by 1,2,3,4,5,6,7,8""" + qt_test_if_and_in_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where if(t2.d_int =5,true, false) not in (FALSE) order by 1,2,3,4,5,6,7,8""" + qt_test_if_and_in_predicate_not_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where if(t2.d_int =5,true, false) !=FALSE order by 1,2,3,4,5,6,7,8""" + qt_test_multi_slot_in_predicate1_res """select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a+t1.c=t2.a+t2.c and t1.a+t1.c<10 order by 1,2,3,4,5,6,7,8""" + qt_test_multi_slot_in_predicate2_res """select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a=t2.a and t1.b=t2.b and t1.a+t1.b<10 order by 1,2,3,4,5,6,7,8""" + qt_test_case_when_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int + where case when t2.d_int=1 then true when t2.d_int=2 then false else false end order by t1.d_int""" + qt_test_datetimev2_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_datetimev2=t2.d_datetimev2 where convert_tz(date_trunc(t2.d_datetimev2, 'month'),'Asia/Shanghai','Europe/Paris')='2024-01-01' order by t1.d_int;""" + + // function predicate + qt_test_convert_tz_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris')>'2022-01-01' order by t1.d_int;""" + qt_test_next_date_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10 order by t1.d_int;""" + qt_test_random_nest_predicate_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),random(1,10)))>10 order by t1.d_int;""" + qt_test_random_predicate_res """select * from extend_infer_t3 t1 inner join extend_infer_t4 t2 on t1.a=t2.a and t1.a>random(10) order by 1,2,3,4,5,6,7,8;""" + qt_test_predicate_map_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on t1.d_datetimev2 =t2.d_datetimev2 and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10 + and convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') < '2022-01-01' order by t1.d_int;""" + + // test cast + qt_test_int_upcast_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint where t2.d_tinyint<10 order by t1.d_int;""" + qt_test_int_downcast_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_int as tinyint)=t2.d_tinyint where t2.d_tinyint<10 order by t1.d_int;""" + qt_test_date_upcast_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_datev2 =t2.d_datetimev2 and t1.d_datev2<'2022-01-03' order by t1.d_int;""" + qt_test_date_downcast_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_datev2 =cast(t2.d_datetimev2 as datev2) and t1.d_datev2<'2022-01-03' order by t1.d_int;""" + qt_test_date_both_upcast1_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on cast(t1.d_datev2 as datetimev2)=cast(t2.d_date as datetimev2) + and t1.d_datev2<'2022-01-03' order by t1.d_int;""" + qt_test_date_both_upcast2_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 on cast(t1.d_datetime as datetimev2)=cast(t2.d_date as datetimev2) + and t1.d_datetime<'2022-01-03' order by t1.d_int;""" + // cast char behave differently because of substring + qt_test_char_different_type1_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_char100=t2.d_char10 and t2.d_char10>'abc' order by t1.d_int;""" + qt_test_char_different_type2_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(50))=t2.d_char10 and t2.d_char10>'abc' order by t1.d_int;""" + qt_test_char_different_type3_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(50))=cast(t2.d_char10 as char(50)) and t2.d_char10>'abc' order by t1.d_int;""" + qt_test_char_different_type4_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_char100 as char(200))=cast(t2.d_char10 as char(200)) and t2.d_char10>'abc' order by t1.d_int;""" + + qt_test_cast_and_func_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on abs(t1.d_int)=t2.d_tinyint where t2.d_tinyint<10 order by t1.d_int;""" + qt_test_cast_and_func2_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(abs(t1.d_int) as tinyint)=t2.d_tinyint where t2.d_tinyint<10 order by t1.d_int;""" + qt_test_cast_and_func3_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on cast(t1.d_int as tinyint)=abs(t2.d_tinyint) where abs(t2.d_tinyint)<10 order by t1.d_int;""" + qt_test_cast_and_func4_res """select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int =abs(t2.d_tinyint) where abs(t2.d_tinyint)<10 order by t1.d_int;""" + qt_test_func_equal_and_nest_func_pred1_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') =convert_tz(t2.d_datetimev2,'Asia/Shanghai','Europe/Paris') + and day(hours_add(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'),10))>10 order by t1.d_int;""" + qt_test_func_equal_and_nest_func_pred2_res """select * from extend_infer_t1 t1 inner join extend_infer_t2 t2 + on convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris') =convert_tz(t2.d_datetimev2,'Asia/Shanghai','Europe/Paris') + and day(convert_tz(t1.d_datetimev2,'Asia/Shanghai','Europe/Paris'))>10 order by t1.d_int;""" + qt_predicate_to_empty_relation_res """select * from extend_infer_t3 t1 left join extend_infer_t4 t2 on t1.a=t2.a and t2.a=1 left join extend_infer_t4 t3 on t1.a=t3.a where t1.a=2""" + qt_equal_table_predicate_delete_res """select * from extend_infer_t3 where a=1 and c=1 order by 1,2,3,4;""" + + // non-inner join + qt_not_equal_inner_left """explain shape plan + select * from extend_infer_t1 t3 inner join ( + select t1.d_int as c1 from extend_infer_t1 t1 left join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1!=10;""" + qt_not_equal_inner_left2 """explain shape plan + select * from extend_infer_t1 t3 inner join ( + select t2.d_int as c1 from extend_infer_t1 t1 left join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1!=10;""" + qt_not_equal_left_inner """explain shape plan + select * from extend_infer_t1 t3 left join ( + select t1.d_int as c1 from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1!=10;""" + qt_not_equal_left_left """explain shape plan + select * from extend_infer_t1 t3 left join ( + select t1.d_int as c1 from extend_infer_t1 t1 left join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1!=10;""" + qt_not_equal_left_left2 """explain shape plan + select * from extend_infer_t1 t3 left join ( + select t2.d_int as c1 from extend_infer_t1 t1 left join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1!=10;""" + + qt_not_in_inner_right """explain shape plan + select * from extend_infer_t1 t3 inner join ( + select t1.d_int as c1 from extend_infer_t1 t1 right join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1 not in (10,20);""" + qt_not_in_inner_right2 """explain shape plan + select * from extend_infer_t1 t3 inner join ( + select t2.d_int as c1 from extend_infer_t1 t1 right join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1 not in (10,20);""" + qt_not_in_right_inner """explain shape plan + select * from extend_infer_t1 t3 right join ( + select t1.d_int as c1 from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1 not in (10,20);""" + qt_not_in_right_right """explain shape plan + select * from extend_infer_t1 t3 right join ( + select t1.d_int as c1 from extend_infer_t1 t1 right join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1 not in (10,20);""" + qt_not_in_right_right2 """explain shape plan + select * from extend_infer_t1 t3 right join ( + select t2.d_int as c1 from extend_infer_t1 t1 right join extend_infer_t1 t2 on t1.d_int=t2.d_int) t on t3.d_int=t.c1 where t.c1 not in (10,20);""" + + qt_not_equal_semi_semi_with_cast """explain shape plan + select * from extend_infer_t1 t3 left semi join ( + select t1.d_int as c1 from extend_infer_t1 t1 left semi join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint) t + on t3.d_smallint=t.c1 where t3.d_smallint !=10;""" + qt_not_equal_anti_anti_with_cast """explain shape plan + select * from extend_infer_t1 t3 left anti join ( + select t1.d_int as c1 from extend_infer_t1 t1 left anti join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint) t + on t3.d_smallint=t.c1 where t3.d_smallint !=10;""" + qt_not_equal_anti_left_with_cast """explain shape plan + select * from extend_infer_t1 t3 left anti join ( + select t1.d_int as c1 from extend_infer_t1 t1 left join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint) t + on t3.d_smallint=t.c1 where t3.d_smallint !=10;""" + qt_not_equal_semi_anti_with_cast """explain shape plan + select * from extend_infer_t1 t3 left semi join ( + select t1.d_int as c1 from extend_infer_t1 t1 left anti join extend_infer_t1 t2 on t1.d_int=t2.d_tinyint) t + on t3.d_smallint=t.c1 where t3.d_smallint !=10;""" + qt_in_subquery_to_semi_join """explain shape plan + select * from extend_infer_t1 t1 where t1.d_int in (select d_int from extend_infer_t2 where d_int != 10) + """ + // should not infer + qt_not_in_subquery_to_na_anti_join_not_infer """explain shape plan + select * from extend_infer_t1 t1 where t1.d_int not in (select d_int from extend_infer_t2 ) and t1.d_int !=10 + """ + qt_in_subquery_to_semi_join """explain shape plan + select * from extend_infer_t1 t1 inner join extend_infer_t1 t2 on t1.d_int=t2.d_int where t1.d_int in (select d_int from extend_infer_t2 where d_int != 10) + """ + + qt_cast_to_decimal_overflow_not_infer """explain shape plan + select 1 from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_tinyint=t2.d_int and t1.d_tinyint in(0.5,0.1)""" + qt_char_equal_int_infer """explain shape plan + select 1 from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_char10=t2.d_int and t1.d_char10 in('d','bb')""" + qt_date_equal_int_infer """explain shape plan + select 1 from extend_infer_t1 t1 inner join extend_infer_t2 t2 on t1.d_datev2=t2.d_int and t1.d_datev2 in('2024-01-01','2024-01-02')""" + +} \ No newline at end of file diff --git a/regression-test/suites/nereids_rules_p0/infer_predicate/infer_unequal_predicates.groovy b/regression-test/suites/nereids_rules_p0/infer_predicate/infer_unequal_predicates.groovy new file mode 100644 index 00000000000000..23eafac414b799 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/infer_predicate/infer_unequal_predicates.groovy @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("infer_unequal_predicates") { + sql """SET ignore_shape_nodes='PhysicalDistribute,PhysicalProject'""" + sql "set runtime_filter_mode = OFF" + sql "set disable_join_reorder=true " + sql "drop table if exists infer_unequal_predicates_t1" + sql """ + CREATE TABLE `infer_unequal_predicates_t1` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + partition by list(d) + (partition p1 values in (5,6), + partition p2 values in (7,8)) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql "insert into infer_unequal_predicates_t1 values(1,'d2',3,5);" + sql "insert into infer_unequal_predicates_t1 values(0,'d2',3,5);" + sql "insert into infer_unequal_predicates_t1 values(0,'d2',3,7);" + + sql "drop table if exists infer_unequal_predicates_t2" + sql """ + CREATE TABLE `infer_unequal_predicates_t2` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + partition by list(d) + (partition p1 values in (5,6), + partition p2 values in (7,8)) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql "insert into infer_unequal_predicates_t2 values(1,'d2',3,5);" + sql "insert into infer_unequal_predicates_t2 values(0,'d2',3,5);" + sql "insert into infer_unequal_predicates_t2 values(0,'d2',3,7);" + + sql "drop table if exists infer_unequal_predicates_t3" + sql """ + create table infer_unequal_predicates_t3(d_int int, d_char100 char(100), d_smallint smallint, d_tinyint tinyint, d_char10 char(10),d_datetimev2 datetimev2, d_datev2 datev2,d_date date, d_datetime datetime) properties('replication_num'='1'); + """ + sql """ + insert into infer_unequal_predicates_t3 values(1,'01234567890123456789', 3,3,'0123456789','2020-01-09 10:00:00.99','2020-01-09','2022-08-09','2022-08-09 10:00:00'),(14,'01234567890123456789', 33,23,'0123456789','2020-01-11 10:00:00.99','2020-01-11','2022-08-03','2022-08-09 10:00:02'); + """ + + // c c<1 should not be inferred + qt_not_infer_same_table_have_mid_column """ + explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 WHERE t1.c c t1.a t1.c<1 should be inferred + qt_infer_diff_table """explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 INNER JOIN infer_unequal_predicates_t2 t2 ON t2.a<1 and t1.c a<1 should be inferred + qt_should_infer_because_a_is_key """ + explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 WHERE t1.a d<1 should be inferred + qt_should_infer_because_d_is_partition_column """ + explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 WHERE t1.d t2.c<1 should be inferred + qt_infer_with_equal """explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 INNER JOIN infer_unequal_predicates_t2 t2 ON t1.a<1 and t1.a=t2.c""" + + // t2.c<1, t1.a t1.a<1 and t2.a<1 should be inferred + qt_infer_4_expr """explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 INNER JOIN infer_unequal_predicates_t2 t2 ON t2.c<1 and t1.a1 AND t1.a=t1.c + """ + qt_infer_long_chain_diff_table """ + explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 INNER JOIN infer_unequal_predicates_t2 t2 ON t1.a1 AND t1.a=t2.c AND t2.ct2.c + """ + + qt_no_infer_cyclic_dependency """ + explain shape plan + SELECT * FROM infer_unequal_predicates_t1 t1 INNER JOIN infer_unequal_predicates_t2 t2 ON t1.at2.d_smallint and t2.d_smallint >1; + """ + + qt_multi_slot_equal """explain shape plan select * from infer_unequal_predicates_t1 where a=c and c=d""" + + qt_no_redundant_predicates """ + explain shape plan + SELECT t1.a FROM (select * from infer_unequal_predicates_t1 t1 where t1.d<10 and t1.d=t1.c and t1.c<10) t1 inner join + infer_unequal_predicates_t2 t2 on t1.d=t2.d where t2.d>1 + """ + + // TODO + // Non equivalent transfer relation derivation, expression is not supported temporarily + qt_expr_unequal_infer_same_table1 """explain shape plan + select * from infer_unequal_predicates_t1 t1 where abs(t1.d) cast(k2 as bigint));" - contains "partitions=2/3 (p2,p3)" + contains "partitions=1/3 (p2)" } //fix BUG: p2 missed