diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 422667af6cfbae..b18f8a67a3a8ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -471,7 +471,7 @@ private static List getWholeTreeRewriteJobs(List jobs) { custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(jobs)) ), topic("or expansion", - topDown(new OrExpansion())), + custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)), topic("whole plan check", custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new) ) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index ff531ffce3880f..9f9257f5f60a41 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -19,10 +19,12 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext; +import org.apache.doris.nereids.rules.rewrite.OrExpansion.OrExpandsionContext; +import org.apache.doris.nereids.trees.copier.DeepCopierContext; +import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; @@ -38,8 +40,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.qe.ConnectContext; @@ -53,6 +58,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -62,7 +68,7 @@ * => / \ * HJ(cond1) HJ(cond2 and !cond1) */ -public class OrExpansion extends OneExplorationRuleFactory { +public class OrExpansion extends DefaultPlanRewriter implements CustomRewriter { public static final OrExpansion INSTANCE = new OrExpansion(); public static final ImmutableSet supportJoinType = new ImmutableSet .Builder() @@ -73,63 +79,101 @@ public class OrExpansion extends OneExplorationRuleFactory { .build(); @Override - public Rule build() { - return logicalJoin(any(), any()).when(JoinUtils::shouldNestedLoopJoin) - .whenNot(LogicalJoin::isMarkJoin) - .when(join -> supportJoinType.contains(join.getJoinType()) - && ConnectContext.get().getSessionVariable().getEnablePipelineEngine()) - .thenApply(ctx -> { - LogicalJoin join = ctx.root; - Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(), - "Only Expansion nest loop join without hashCond"); + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + OrExpandsionContext ctx = new OrExpandsionContext( + jobContext.getCascadesContext().getStatementContext(), jobContext.getCascadesContext()); + plan = plan.accept(this, ctx); + for (int i = ctx.cteProducerList.size() - 1; i >= 0; i--) { + LogicalCTEProducer producer = ctx.cteProducerList.get(i); + plan = new LogicalCTEAnchor<>(producer.getCteId(), producer, plan); + } + return plan; + } + + @Override + public Plan visit(Plan plan, OrExpandsionContext ctx) { + List newChildren = new ArrayList<>(); + boolean hasNewChildren = false; + for (Plan child : plan.children()) { + Plan newChild = child.accept(this, ctx); + if (newChild != child) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + return hasNewChildren ? plan.withChildren(newChildren) : plan; + } - //1. Try to split or conditions - Pair, List> hashOtherConditions = splitOrCondition(join); - if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) { - return join; - } + @Override + public Plan visitLogicalJoin(LogicalJoin join, OrExpandsionContext ctx) { + join = (LogicalJoin) this.visit(join, ctx); + if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) { + return join; + } + if (!(supportJoinType.contains(join.getJoinType()) + && ConnectContext.get().getSessionVariable().getEnablePipelineEngine())) { + return join; + } + Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(), + "Only Expansion nest loop join without hashCond"); - //2. Construct CTE with the children - LogicalCTEProducer leftProducer = new LogicalCTEProducer<>( - ctx.statementContext.getNextCTEId(), join.left()); - LogicalCTEProducer rightProducer = new LogicalCTEProducer<>( - ctx.statementContext.getNextCTEId(), join.right()); - List joins = new ArrayList<>(); + //1. Try to split or conditions + Pair, List> hashOtherConditions = splitOrCondition(join); + if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) { + return join; + } - // 3. Expand join to hash join with CTE - if (join.getJoinType().isInnerJoin()) { - joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions, - join, leftProducer, rightProducer)); - } else if (join.getJoinType().isOuterJoin()) { - // left outer join = inner join union left anti join - joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions, - join, leftProducer, rightProducer)); - joins.add(expandLeftAntiJoin(ctx.cascadesContext, - hashOtherConditions, join, leftProducer, rightProducer)); - if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) { - // full outer join = inner join union left anti join union right anti join - joins.add(expandLeftAntiJoin(ctx.cascadesContext, - hashOtherConditions, join, rightProducer, leftProducer)); - } - } else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) { - joins.add(expandLeftAntiJoin(ctx.cascadesContext, - hashOtherConditions, join, leftProducer, rightProducer)); - } else { - throw new RuntimeException("or-expansion is not supported for " + join); - } + //2. Construct CTE with the children + LogicalPlan leftClone = LogicalPlanDeepCopier.INSTANCE + .deepCopy((LogicalPlan) join.left(), new DeepCopierContext()); + LogicalCTEProducer leftProducer = new LogicalCTEProducer<>( + ctx.statementContext.getNextCTEId(), leftClone); + LogicalPlan rightClone = LogicalPlanDeepCopier.INSTANCE + .deepCopy((LogicalPlan) join.right(), new DeepCopierContext()); + LogicalCTEProducer rightProducer = new LogicalCTEProducer<>( + ctx.statementContext.getNextCTEId(), rightClone); + Map leftCloneToLeft = new HashMap<>(); + for (int i = 0; i < leftClone.getOutput().size(); i++) { + leftCloneToLeft.put(leftClone.getOutput().get(i), (join.left()).getOutput().get(i)); + } + Map rightCloneToRight = new HashMap<>(); + for (int i = 0; i < rightClone.getOutput().size(); i++) { + rightCloneToRight.put(rightClone.getOutput().get(i), (join.right()).getOutput().get(i)); + } - //4. union all joins and construct LogicalCTEAnchor with CTEs - List> childrenOutputs = joins.stream() - .map(j -> j.getOutput().stream() - .map(SlotReference.class::cast) - .collect(ImmutableList.toImmutableList())) - .collect(ImmutableList.toImmutableList()); - LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()), - childrenOutputs, ImmutableList.of(), false, joins); - LogicalCTEAnchor intermediateAnchor = new LogicalCTEAnchor<>( - rightProducer.getCteId(), rightProducer, union); - return new LogicalCTEAnchor(leftProducer.getCteId(), leftProducer, intermediateAnchor); - }).toRule(RuleType.OR_EXPANSION); + // 3. Expand join to hash join with CTE + List joins = new ArrayList<>(); + if (join.getJoinType().isInnerJoin()) { + joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions, + join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight)); + } else if (join.getJoinType().isOuterJoin()) { + // left outer join = inner join union left anti join + joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions, + join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight)); + joins.add(expandLeftAntiJoin(ctx.cascadesContext, + hashOtherConditions, join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight)); + if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) { + // full outer join = inner join union left anti join union right anti join + joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions, + join, rightProducer, leftProducer, rightCloneToRight, leftCloneToLeft)); + } + } else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) { + joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions, + join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight)); + } else { + throw new RuntimeException("or-expansion is not supported for " + join); + } + //4. union all joins and put producers to context + List> childrenOutputs = joins.stream() + .map(j -> j.getOutput().stream() + .map(SlotReference.class::cast) + .collect(ImmutableList.toImmutableList())) + .collect(ImmutableList.toImmutableList()); + LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()), + childrenOutputs, ImmutableList.of(), false, joins); + ctx.cteProducerList.add(leftProducer); + ctx.cteProducerList.add(rightProducer); + return union; } // try to find a condition that can be split into hash conditions @@ -150,6 +194,18 @@ public Rule build() { return null; } + private Map constructReplaceMap(LogicalCTEConsumer leftConsumer, Map leftCloneToLeft, + LogicalCTEConsumer rightConsumer, Map rightCloneToRight) { + Map replaced = new HashMap<>(); + for (Entry entry : leftConsumer.getProducerToConsumerOutputMap().entrySet()) { + replaced.put(leftCloneToLeft.get(entry.getKey()), entry.getValue()); + } + for (Entry entry : rightConsumer.getProducerToConsumerOutputMap().entrySet()) { + replaced.put(rightCloneToRight.get(entry.getKey()), entry.getValue()); + } + return replaced; + } + // expand Anti Join: // Left Anti join cond1 or cond2, other Left Anti join cond1 and other // / \ / \ @@ -160,7 +216,8 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, Pair, List> hashOtherConditions, LogicalJoin originJoin, LogicalCTEProducer leftProducer, - LogicalCTEProducer rightProducer) { + LogicalCTEProducer rightProducer, + Map leftCloneToLeft, Map rightCloneToRight) { LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), leftProducer.getCteId(), "", leftProducer); LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), @@ -168,8 +225,7 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, ctx.putCTEIdToConsumer(left); ctx.putCTEIdToConsumer(right); - Map replaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - replaced.putAll(right.getProducerToConsumerOutputMap()); + Map replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight); List disjunctions = hashOtherConditions.first; List otherConditions = hashOtherConditions.second; List newOtherConditions = otherConditions.stream() @@ -191,8 +247,7 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, LogicalCTEConsumer newRight = new LogicalCTEConsumer( ctx.getStatementContext().getNextRelationId(), rightProducer.getCteId(), "", rightProducer); ctx.putCTEIdToConsumer(newRight); - Map newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - newReplaced.putAll(newRight.getProducerToConsumerOutputMap()); + Map newReplaced = constructReplaceMap(left, leftCloneToLeft, newRight, rightCloneToRight); newOtherConditions = otherConditions.stream() .map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s)) .collect(Collectors.toList()); @@ -224,7 +279,8 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, private List expandInnerJoin(CascadesContext ctx, Pair, List> hashOtherConditions, LogicalJoin join, LogicalCTEProducer leftProducer, - LogicalCTEProducer rightProducer) { + LogicalCTEProducer rightProducer, + Map leftCloneToLeft, Map rightCloneToRight) { List disjunctions = hashOtherConditions.first; List otherConditions = hashOtherConditions.second; // For null values, equalTo and not equalTo both return false @@ -248,8 +304,7 @@ private List expandInnerJoin(CascadesContext ctx, Pair, ctx.putCTEIdToConsumer(right); //rewrite conjuncts to replace the old slots with CTE slots - Map replaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - replaced.putAll(right.getProducerToConsumerOutputMap()); + Map replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight); List hashCond = pair.first.stream() .map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)) .collect(Collectors.toList()); @@ -283,4 +338,16 @@ private Pair, List> extractHashAndOtherConditions(i } return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others); } + + class OrExpandsionContext { + List> cteProducerList; + StatementContext statementContext; + CascadesContext cascadesContext; + + public OrExpandsionContext(StatementContext statementContext, CascadesContext cascadesContext) { + this.statementContext = statementContext; + this.cteProducerList = new ArrayList<>(); + this.cascadesContext = cascadesContext; + } + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java new file mode 100644 index 00000000000000..9f8bd8bcc555c4 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrExpansionTest.java @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class OrExpansionTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + connectContext.setDatabase("default_cluster:test"); + createTables( + "CREATE TABLE IF NOT EXISTS t1 (\n" + + " id1 int not null,\n" + + " id2 int not null\n" + + ")\n" + + "DUPLICATE KEY(id1)\n" + + "DISTRIBUTED BY HASH(id1) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n", + "CREATE TABLE IF NOT EXISTS t2 (\n" + + " id1 int not null,\n" + + " id2 int not null\n" + + ")\n" + + "DUPLICATE KEY(id1)\n" + + "DISTRIBUTED BY HASH(id2) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n" + ); + } + + @Test + void testOrExpand() { + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + String sql = "select t1.id1 + 1 as id from t1 join t2 on t1.id1 = t2.id1 or t1.id2 = t2.id2"; + Plan plan = PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .printlnTree() + .getPlan(); + Assertions.assertTrue(plan instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor); + } + + @Test + void testOrExpandCTE() { + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + connectContext.getSessionVariable().inlineCTEReferencedThreshold = 0; + String sql = "with t3 as (select t1.id1 + 1 as id1, t1.id2 + 2 as id2 from t1), " + + "t4 as (select t2.id1 + 1 as id1, t2.id2 + 2 as id2 from t2) " + + "select t3.id1 from " + + "(select id1, id2 from t3 group by id1, id2) t3 " + + " join " + + "(select id1, id2 from t4 group by id1, id2) t4 " + + "on t3.id1 = t4.id1 or t3.id2 = t4.id2"; + Plan plan = PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .printlnTree() + .getPlan(); + Assertions.assertTrue(plan instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1).child(1) instanceof LogicalCTEAnchor); + Assertions.assertTrue(plan.child(1).child(1).child(1) instanceof LogicalCTEAnchor); + } +}