Skip to content

Commit

Permalink
[fix](Nereids): fix some bugs in or expansion #34840
Browse files Browse the repository at this point in the history
add unit test
  • Loading branch information
keanji-x authored May 14, 2024
1 parent 47f0a67 commit bac5172
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ public Plan visit(Plan plan, OrExpandsionContext ctx) {
return hasNewChildren ? plan.withChildren(newChildren) : plan;
}

@Override
public Plan visitLogicalCTEAnchor(
LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, OrExpandsionContext ctx) {
Plan child1 = this.visit(anchor.child(0), ctx);
// Consumer's CTE must be child of the cteAnchor in this case:
// anchor
// +-producer1
// +-agg(consumer1) join agg(consumer1)
// ------------>
// anchor
// +-producer1
// +-anchor
// +--producer2(agg2(consumer1))
// +--producer3(agg3(consumer1))
// +-consumer2 join consumer3
OrExpandsionContext consumerContext =
new OrExpandsionContext(ctx.statementContext, ctx.cascadesContext);
Plan child2 = this.visit(anchor.child(1), consumerContext);
for (int i = consumerContext.cteProducerList.size() - 1; i >= 0; i--) {
LogicalCTEProducer<? extends Plan> producer = consumerContext.cteProducerList.get(i);
child2 = new LogicalCTEAnchor<>(producer.getCteId(), producer, child2);
}
return anchor.withChildren(ImmutableList.of(child1, child2));
}

@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, OrExpandsionContext ctx) {
join = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,14 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, D
List<Expression> markJoinConjuncts = join.getMarkJoinConjuncts().stream()
.map(c -> ExpressionDeepCopier.INSTANCE.deepCopy(c, context))
.collect(ImmutableList.toImmutableList());
Optional<MarkJoinSlotReference> markJoinSlotReference = Optional.empty();
if (join.getMarkJoinSlotReference().isPresent()) {
markJoinSlotReference = Optional.of((MarkJoinSlotReference) ExpressionDeepCopier.INSTANCE
.deepCopy(join.getMarkJoinSlotReference().get(), context));

}
return new LogicalJoin<>(join.getJoinType(), hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
join.getDistributeHint(), join.getMarkJoinSlotReference(), children, join.getJoinReorderContext());
join.getDistributeHint(), markJoinSlotReference, children, join.getJoinReorderContext());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
Expand Down Expand Up @@ -81,6 +82,9 @@ void testOrExpandCTE() {
Assertions.assertTrue(plan instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1) instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1).child(1) instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1).child(1).anyMatch(x -> x instanceof LogicalCTEConsumer));
Assertions.assertTrue(plan.child(1).child(1).child(1) instanceof LogicalCTEAnchor);
Assertions.assertTrue(plan.child(1).child(1).child(1)
.anyMatch(x -> x instanceof LogicalCTEConsumer));
}
}

0 comments on commit bac5172

Please sign in to comment.