Skip to content

Commit

Permalink
[fix](Nereids): clone the producer plan and put logicalAnchor generat…
Browse files Browse the repository at this point in the history
…ed by `Or_Expansion` above `logicalSink` (#34771)

* put cte anchor on the root

put logicalAnchor on root

clone plan of cte consumer

* fix unit test
  • Loading branch information
keanji-x authored and Doris-Extras committed May 14, 2024
1 parent 5ece07a commit 0deb629
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(List<RewriteJob> jobs) {
custom(RuleType.REWRITE_CTE_CHILDREN, () -> new RewriteCteChildren(jobs))
),
topic("or expansion",
topDown(new OrExpansion())),
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)),
topic("whole plan check",
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.rules.rewrite.OrExpansion.OrExpandsionContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
Expand All @@ -38,8 +40,11 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
Expand All @@ -53,6 +58,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

Expand All @@ -62,7 +68,7 @@
* => / \
* HJ(cond1) HJ(cond2 and !cond1)
*/
public class OrExpansion extends OneExplorationRuleFactory {
public class OrExpansion extends DefaultPlanRewriter<OrExpandsionContext> implements CustomRewriter {
public static final OrExpansion INSTANCE = new OrExpansion();
public static final ImmutableSet<JoinType> supportJoinType = new ImmutableSet
.Builder<JoinType>()
Expand All @@ -73,63 +79,101 @@ public class OrExpansion extends OneExplorationRuleFactory {
.build();

@Override
public Rule build() {
return logicalJoin(any(), any()).when(JoinUtils::shouldNestedLoopJoin)
.whenNot(LogicalJoin::isMarkJoin)
.when(join -> supportJoinType.contains(join.getJoinType())
&& ConnectContext.get().getSessionVariable().getEnablePipelineEngine())
.thenApply(ctx -> {
LogicalJoin<? extends Plan, ? extends Plan> join = ctx.root;
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");
// Entry point of the custom rewriter: runs the or-expansion visitor over the
// whole plan, then wraps the rewritten root in one LogicalCTEAnchor per CTE
// producer that the visitor registered in the context.
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    CascadesContext cascadesContext = jobContext.getCascadesContext();
    OrExpandsionContext expansionCtx =
            new OrExpandsionContext(cascadesContext.getStatementContext(), cascadesContext);
    Plan rewritten = plan.accept(this, expansionCtx);

    // Wrap from the last registered producer back to the first, so the
    // producer registered first ends up as the outermost anchor.
    List<LogicalCTEProducer<? extends Plan>> producers = expansionCtx.cteProducerList;
    for (int remaining = producers.size(); remaining > 0; remaining--) {
        LogicalCTEProducer<? extends Plan> producer = producers.get(remaining - 1);
        rewritten = new LogicalCTEAnchor<>(producer.getCteId(), producer, rewritten);
    }
    return rewritten;
}

/**
 * Default traversal: rewrites every child with this visitor and rebuilds the
 * node only when at least one child came back as a different instance.
 *
 * @param plan the node whose children are rewritten
 * @param ctx  the or-expansion context passed down to every child
 * @return {@code plan} itself when no child changed (reference comparison),
 *         otherwise {@code plan.withChildren(...)} with the rewritten children
 */
@Override
public Plan visit(Plan plan, OrExpandsionContext ctx) {
    boolean changed = false;
    List<Plan> rewrittenChildren = new ArrayList<>();
    for (Plan original : plan.children()) {
        Plan rewritten = original.accept(this, ctx);
        // Identity, not equals(): an untouched subtree is returned as-is.
        changed |= rewritten != original;
        rewrittenChildren.add(rewritten);
    }
    if (!changed) {
        return plan;
    }
    return plan.withChildren(rewrittenChildren);
}

//1. Try to split or conditions
Pair<List<Expression>, List<Expression>> hashOtherConditions = splitOrCondition(join);
if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) {
return join;
}
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, OrExpandsionContext ctx) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, ctx);
if (join.isMarkJoin() || !JoinUtils.shouldNestedLoopJoin(join)) {
return join;
}
if (!(supportJoinType.contains(join.getJoinType())
&& ConnectContext.get().getSessionVariable().getEnablePipelineEngine())) {
return join;
}
Preconditions.checkArgument(join.getHashJoinConjuncts().isEmpty(),
"Only Expansion nest loop join without hashCond");

//2. Construct CTE with the children
LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.left());
LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), join.right());
List<Plan> joins = new ArrayList<>();
//1. Try to split or conditions
Pair<List<Expression>, List<Expression>> hashOtherConditions = splitOrCondition(join);
if (hashOtherConditions == null || hashOtherConditions.first.size() <= 1) {
return join;
}

// 3. Expand join to hash join with CTE
if (join.getJoinType().isInnerJoin()) {
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer));
} else if (join.getJoinType().isOuterJoin()) {
// left outer join = inner join union left anti join
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer));
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer));
if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
// full outer join = inner join union left anti join union right anti join
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, rightProducer, leftProducer));
}
} else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer));
} else {
throw new RuntimeException("or-expansion is not supported for " + join);
}
//2. Construct CTE with the children
LogicalPlan leftClone = LogicalPlanDeepCopier.INSTANCE
.deepCopy((LogicalPlan) join.left(), new DeepCopierContext());
LogicalCTEProducer<? extends Plan> leftProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), leftClone);
LogicalPlan rightClone = LogicalPlanDeepCopier.INSTANCE
.deepCopy((LogicalPlan) join.right(), new DeepCopierContext());
LogicalCTEProducer<? extends Plan> rightProducer = new LogicalCTEProducer<>(
ctx.statementContext.getNextCTEId(), rightClone);
Map<Slot, Slot> leftCloneToLeft = new HashMap<>();
for (int i = 0; i < leftClone.getOutput().size(); i++) {
leftCloneToLeft.put(leftClone.getOutput().get(i), (join.left()).getOutput().get(i));
}
Map<Slot, Slot> rightCloneToRight = new HashMap<>();
for (int i = 0; i < rightClone.getOutput().size(); i++) {
rightCloneToRight.put(rightClone.getOutput().get(i), (join.right()).getOutput().get(i));
}

//4. union all joins and construct LogicalCTEAnchor with CTEs
List<List<SlotReference>> childrenOutputs = joins.stream()
.map(j -> j.getOutput().stream()
.map(SlotReference.class::cast)
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
childrenOutputs, ImmutableList.of(), false, joins);
LogicalCTEAnchor<? extends Plan, ? extends Plan> intermediateAnchor = new LogicalCTEAnchor<>(
rightProducer.getCteId(), rightProducer, union);
return new LogicalCTEAnchor<Plan, Plan>(leftProducer.getCteId(), leftProducer, intermediateAnchor);
}).toRule(RuleType.OR_EXPANSION);
// 3. Expand join to hash join with CTE
List<Plan> joins = new ArrayList<>();
if (join.getJoinType().isInnerJoin()) {
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
} else if (join.getJoinType().isOuterJoin()) {
// left outer join = inner join union left anti join
joins.addAll(expandInnerJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
joins.add(expandLeftAntiJoin(ctx.cascadesContext,
hashOtherConditions, join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
if (join.getJoinType().equals(JoinType.FULL_OUTER_JOIN)) {
// full outer join = inner join union left anti join union right anti join
joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
join, rightProducer, leftProducer, rightCloneToRight, leftCloneToLeft));
}
} else if (join.getJoinType().equals(JoinType.LEFT_ANTI_JOIN)) {
joins.add(expandLeftAntiJoin(ctx.cascadesContext, hashOtherConditions,
join, leftProducer, rightProducer, leftCloneToLeft, rightCloneToRight));
} else {
throw new RuntimeException("or-expansion is not supported for " + join);
}
//4. union all joins and put producers to context
List<List<SlotReference>> childrenOutputs = joins.stream()
.map(j -> j.getOutput().stream()
.map(SlotReference.class::cast)
.collect(ImmutableList.toImmutableList()))
.collect(ImmutableList.toImmutableList());
LogicalUnion union = new LogicalUnion(Qualifier.ALL, new ArrayList<>(join.getOutput()),
childrenOutputs, ImmutableList.of(), false, joins);
ctx.cteProducerList.add(leftProducer);
ctx.cteProducerList.add(rightProducer);
return union;
}

// try to find a condition that can be split into hash conditions
Expand All @@ -150,6 +194,18 @@ public Rule build() {
return null;
}

/**
 * Builds the slot substitution map used to rewrite join conditions onto CTE consumers.
 *
 * <p>Each consumer exposes a producer-slot to consumer-slot map. The producer-side
 * key is translated through the matching {@code *CloneTo*} map, and the translated
 * slot becomes the key of the result, pointing at the consumer slot.
 *
 * @param leftConsumer      consumer standing in for the left join child
 * @param leftCloneToLeft   lookup applied to the left consumer's producer slots
 * @param rightConsumer     consumer standing in for the right join child
 * @param rightCloneToRight lookup applied to the right consumer's producer slots
 * @return a new mutable map from translated slot to consumer slot; right-side
 *         entries are inserted after left-side ones and overwrite on key collision
 */
private Map<Slot, Slot> constructReplaceMap(LogicalCTEConsumer leftConsumer, Map<Slot, Slot> leftCloneToLeft,
        LogicalCTEConsumer rightConsumer, Map<Slot, Slot> rightCloneToRight) {
    Map<Slot, Slot> originToConsumer = new HashMap<>();
    // NOTE(review): a producer slot missing from the clone map would be stored
    // under whatever get() returns for it (null for a HashMap); confirm callers
    // always pass complete maps.
    leftConsumer.getProducerToConsumerOutputMap().forEach(
            (producerSlot, consumerSlot) -> originToConsumer.put(leftCloneToLeft.get(producerSlot), consumerSlot));
    rightConsumer.getProducerToConsumerOutputMap().forEach(
            (producerSlot, consumerSlot) -> originToConsumer.put(rightCloneToRight.get(producerSlot), consumerSlot));
    return originToConsumer;
}

// expand Anti Join:
// Left Anti join cond1 or cond2, other Left Anti join cond1 and other
// / \ / \
Expand All @@ -160,16 +216,16 @@ private Plan expandLeftAntiJoin(CascadesContext ctx,
Pair<List<Expression>, List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> originJoin,
LogicalCTEProducer<? extends Plan> leftProducer,
LogicalCTEProducer<? extends org.apache.doris.nereids.trees.plans.Plan> rightProducer) {
LogicalCTEProducer<? extends org.apache.doris.nereids.trees.plans.Plan> rightProducer,
Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
LogicalCTEConsumer left = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
leftProducer.getCteId(), "", leftProducer);
LogicalCTEConsumer right = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(),
rightProducer.getCteId(), "", rightProducer);
ctx.putCTEIdToConsumer(left);
ctx.putCTEIdToConsumer(right);

Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
List<Expression> newOtherConditions = otherConditions.stream()
Expand All @@ -191,8 +247,7 @@ private Plan expandLeftAntiJoin(CascadesContext ctx,
LogicalCTEConsumer newRight = new LogicalCTEConsumer(
ctx.getStatementContext().getNextRelationId(), rightProducer.getCteId(), "", rightProducer);
ctx.putCTEIdToConsumer(newRight);
Map<Slot, Slot> newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap());
newReplaced.putAll(newRight.getProducerToConsumerOutputMap());
Map<Slot, Slot> newReplaced = constructReplaceMap(left, leftCloneToLeft, newRight, rightCloneToRight);
newOtherConditions = otherConditions.stream()
.map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s))
.collect(Collectors.toList());
Expand Down Expand Up @@ -224,7 +279,8 @@ private Plan expandLeftAntiJoin(CascadesContext ctx,
private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,
List<Expression>> hashOtherConditions,
LogicalJoin<? extends Plan, ? extends Plan> join, LogicalCTEProducer<? extends Plan> leftProducer,
LogicalCTEProducer<? extends Plan> rightProducer) {
LogicalCTEProducer<? extends Plan> rightProducer,
Map<Slot, Slot> leftCloneToLeft, Map<Slot, Slot> rightCloneToRight) {
List<Expression> disjunctions = hashOtherConditions.first;
List<Expression> otherConditions = hashOtherConditions.second;
// For null values, equalTo and not equalTo both return false
Expand All @@ -248,8 +304,7 @@ private List<Plan> expandInnerJoin(CascadesContext ctx, Pair<List<Expression>,
ctx.putCTEIdToConsumer(right);

//rewrite conjuncts to replace the old slots with CTE slots
Map<Slot, Slot> replaced = new HashMap<>(left.getProducerToConsumerOutputMap());
replaced.putAll(right.getProducerToConsumerOutputMap());
Map<Slot, Slot> replaced = constructReplaceMap(left, leftCloneToLeft, right, rightCloneToRight);
List<Expression> hashCond = pair.first.stream()
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
Expand Down Expand Up @@ -283,4 +338,16 @@ private Pair<List<Expression>, List<Expression>> extractHashAndOtherConditions(i
}
return Pair.of(Lists.newArrayList(equal.get(hashCondIdx)), others);
}

/**
 * State carried through one or-expansion rewrite pass.
 *
 * <p>The references are fixed at construction time; only the contents of
 * {@link #cteProducerList} change while the plan is visited.
 *
 * <p>NOTE(review): nothing visible here uses the enclosing instance, so this
 * could probably be a static nested class; confirm no caller relies on it
 * being an inner class before changing that.
 */
class OrExpandsionContext {
    // CTE producers registered during the visit, in registration order.
    final List<LogicalCTEProducer<? extends Plan>> cteProducerList = new ArrayList<>();
    // Source of fresh CTE ids for the producers created during expansion.
    final StatementContext statementContext;
    // Passed to the expand* helpers, which create and register CTE consumers on it.
    final CascadesContext cascadesContext;

    /**
     * @param statementContext statement-level context used to allocate ids
     * @param cascadesContext  context handed to the join expansion helpers
     */
    public OrExpandsionContext(StatementContext statementContext, CascadesContext cascadesContext) {
        this.statementContext = statementContext;
        this.cascadesContext = cascadesContext;
    }
}
}
Loading

0 comments on commit 0deb629

Please sign in to comment.