Skip to content

Commit

Permalink
[feat](Nereids) Reject Commutativity Swap for Nested Loop Joins Affec…
Browse files Browse the repository at this point in the history
…ting Parallelism (#34639) (#34996)

pick from master #34639

This PR introduces a safeguard to prevent commutativity swaps in nested loop joins that would convert a parallelizable join into a non-parallelizable one, thereby preserving optimal query execution efficiency. By adding a function that assesses the impact of such swaps on parallelism, the system automatically rejects changes that would hinder performance, ensuring that joins can continue to be executed in parallel to fully utilize system resources and maintain high operational throughput.
  • Loading branch information
keanji-x authored May 21, 2024
1 parent 7bee558 commit 9e386c0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.BitmapContains;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.planner.NestedLoopJoinNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterType;

Expand Down Expand Up @@ -59,6 +62,11 @@ public Rule build() {
.whenNot(LogicalJoin::hasJoinHint)
.whenNot(join -> joinOrderMatchBitmapRuntimeFilterOrder(join))
.whenNot(LogicalJoin::isMarkJoin)
// For a nested loop join, if commutativity causes a join that could originally be executed
// in parallel to become non-parallelizable, then we reject this swap.
.whenNot(join -> JoinUtils.shouldNestedLoopJoin(join)
&& NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType()))
&& !NestedLoopJoinNode.canParallelize(JoinType.toJoinOperator(join.getJoinType().swap())))
.then(join -> {
LogicalJoin<Plan, Plan> newJoin = join.withTypeChildren(join.getJoinType().swap(),
join.right(), join.left());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,16 @@ public NestedLoopJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRe
tupleIds.addAll(inner.getOutputTupleIds());
}

public boolean canParallelize() {
public static boolean canParallelize(JoinOperator joinOp) {
return joinOp == JoinOperator.CROSS_JOIN || joinOp == JoinOperator.INNER_JOIN
|| joinOp == JoinOperator.LEFT_OUTER_JOIN || joinOp == JoinOperator.LEFT_SEMI_JOIN
|| joinOp == JoinOperator.LEFT_ANTI_JOIN || joinOp == JoinOperator.NULL_AWARE_LEFT_ANTI_JOIN;
}

public boolean canParallelize() {
return canParallelize(joinOp);
}

public void setJoinConjuncts(List<Expr> joinConjuncts) {
this.joinConjuncts = joinConjuncts;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.doris.nereids.rules.exploration.join;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
Expand All @@ -27,11 +29,12 @@
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;

import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;

public class JoinCommuteTest implements MemoPatternMatchSupported {
@Test
public void testInnerJoinCommute() {
void testInnerJoinCommute() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

Expand All @@ -51,4 +54,21 @@ public void testInnerJoinCommute() {
)
;
}

@Test
void testParallelJoinCommute() {
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);

LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
.build();
join = join.withJoinConjuncts(
ImmutableList.of(),
ImmutableList.of(new GreaterThan(scan1.getOutput().get(0), scan2.getOutput().get(0))));

PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.applyExploration(JoinCommute.BUSHY.build())
.printlnTree();
}
}

0 comments on commit 9e386c0

Please sign in to comment.