From 0ef0ac8e4833a8772fb2f8e83c465a4fb038b88c Mon Sep 17 00:00:00 2001 From: Jia Ke Date: Fri, 28 Jun 2024 18:13:07 -0700 Subject: [PATCH] Enable right join in smj (#10148) Summary: The semantics of the right join are similar to the left join, so we referenced the implementation of the left join to achieve the implementation of the right join. Pull Request resolved: https://github.com/facebookincubator/velox/pull/10148 Reviewed By: bikramSingh91 Differential Revision: D59176120 Pulled By: pedroerp fbshipit-source-id: 95184725dfa5fea9317c822d7761507bc49fca9b --- velox/exec/MergeJoin.cpp | 177 +++++++++++++++++++++++------ velox/exec/MergeJoin.h | 18 ++- velox/exec/fuzzer/JoinFuzzer.cpp | 2 +- velox/exec/tests/MergeJoinTest.cpp | 143 +++++++++++++++++++---- 4 files changed, 280 insertions(+), 60 deletions(-) diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index b4281ea970d6..87e3dbf8648d 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -20,6 +20,13 @@ namespace facebook::velox::exec { +namespace { +bool supportsMergeJoin(std::shared_ptr joinNode) { + return joinNode->isInnerJoin() || joinNode->isLeftJoin() || + joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() || + joinNode->isAntiJoin() || joinNode->isRightJoin(); +} +} // namespace MergeJoin::MergeJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -35,10 +42,9 @@ MergeJoin::MergeJoin( numKeys_{joinNode->leftKeys().size()}, joinNode_(joinNode) { VELOX_USER_CHECK( - joinNode_->isInnerJoin() || joinNode_->isLeftJoin() || - joinNode_->isLeftSemiFilterJoin() || - joinNode_->isRightSemiFilterJoin() || joinNode_->isAntiJoin(), - "Merge join supports only inner, left and left semi joins. Other join types are not supported yet."); + supportsMergeJoin(joinNode_), + "The join type is not supported by merge join: ", + joinTypeName(joinNode_->joinType())); } void MergeJoin::initialize() { @@ -89,13 +95,14 @@ void MergeJoin::initialize() { if (joinNode_->filter()) { initializeFilter(joinNode_->filter(), leftType, rightType); - if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin()) { - leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool()); + if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || + joinNode_->isRightJoin()) { + joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { // Anti join needs to track the left side rows that have no match on the // right. - leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool()); + joinTracker_ = JoinTracker(outputBatchSize_, pool()); } joinNode_.reset(); @@ -183,6 +190,9 @@ BlockingReason MergeJoin::isBlocked(ContinueFuture* future) { } bool MergeJoin::needsInput() const { + if (isRightJoin(joinType_)) { + return (input_ == nullptr || rightInput_ == nullptr); + } return input_ == nullptr; } @@ -190,8 +200,8 @@ void MergeJoin::addInput(RowVectorPtr input) { input_ = std::move(input); index_ = 0; - if (leftJoinTracker_) { - leftJoinTracker_->resetLastVector(); + if (joinTracker_) { + joinTracker_->resetLastVector(); } } @@ -269,6 +279,7 @@ void copyRow( void MergeJoin::addOutputRowForLeftJoin( const RowVectorPtr& left, vector_size_t leftIndex) { + VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_)); rawLeftIndices_[outputSize_] = leftIndex; for (const auto& projection : rightProjections_) { @@ -276,9 +287,28 @@ void MergeJoin::addOutputRowForLeftJoin( target->setNull(outputSize_, true); } - if (leftJoinTracker_) { + if (joinTracker_) { // Record left-side row with no match on the right side. - leftJoinTracker_->addMiss(outputSize_); + joinTracker_->addMiss(outputSize_); + } + + ++outputSize_; +} + +void MergeJoin::addOutputRowForRightJoin( + const RowVectorPtr& right, + vector_size_t rightIndex) { + VELOX_USER_CHECK(isRightJoin(joinType_)); + rawRightIndices_[outputSize_] = rightIndex; + + for (const auto& projection : leftProjections_) { + const auto& target = output_->childAt(projection.outputChannel); + target->setNull(outputSize_, true); + } + + if (joinTracker_) { + // Record right-side row with no match on the left side. + joinTracker_->addMiss(outputSize_); } ++outputSize_; @@ -320,18 +350,23 @@ void MergeJoin::addOutputRow( copyRow(left, leftIndex, filterInput_, outputSize_, filterLeftInputs_); copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); - if (leftJoinTracker_) { - // Record left-side row with a match on the right-side. - leftJoinTracker_->addMatch(left, leftIndex, outputSize_); + if (joinTracker_) { + if (isRightJoin(joinType_)) { + // Record right-side row with a match on the left-side. + joinTracker_->addMatch(right, rightIndex, outputSize_); + } else { + // Record left-side row with a match on the right-side. + joinTracker_->addMatch(left, leftIndex, outputSize_); + } } } // Anti join needs to track the left side rows that have no match on the // right. if (isAntiJoin(joinType_)) { - VELOX_CHECK(leftJoinTracker_); + VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. - leftJoinTracker_->addMatch(left, leftIndex, outputSize_); + joinTracker_->addMatch(left, leftIndex, outputSize_); } ++outputSize_; @@ -348,6 +383,10 @@ bool MergeJoin::prepareOutput( return true; } + if (isRightJoin(joinType_) && right != currentRight_) { + return true; + } + // If there is a new right, we need to flatten the dictionary. if (!isRightFlattened_ && right && currentRight_ != right) { flattenRightProjections(); @@ -363,14 +402,23 @@ bool MergeJoin::prepareOutput( rightIndices_ = allocateIndices(outputBatchSize_, pool()); rawRightIndices_ = rightIndices_->asMutable(); - // Create output dictionary vectors for left projections. + // Create left side projection outputs. std::vector localColumns(outputType_->size()); - for (const auto& projection : leftProjections_) { - localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( - {}, - leftIndices_, - outputBatchSize_, - newLeft->childAt(projection.inputChannel)); + if (newLeft == nullptr) { + for (const auto& projection : leftProjections_) { + localColumns[projection.outputChannel] = BaseVector::create( + outputType_->childAt(projection.outputChannel), + outputBatchSize_, + operatorCtx_->pool()); + } + } else { + for (const auto& projection : leftProjections_) { + localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( + {}, + leftIndices_, + outputBatchSize_, + newLeft->childAt(projection.inputChannel)); + } } currentLeft_ = newLeft; @@ -556,7 +604,7 @@ vector_size_t firstNonNull( RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) { auto numRows = output->size(); - const auto& filterRows = leftJoinTracker_->matchingRows(numRows); + const auto& filterRows = joinTracker_->matchingRows(numRows); auto numPassed = 0; BufferPtr indices = allocateIndices(numRows, pool()); @@ -738,6 +786,35 @@ RowVectorPtr MergeJoin::doGetOutput() { output_->resize(outputSize_); return std::move(output_); } + } else if (isRightJoin(joinType_)) { + if (rightInput_ && noMoreInput_) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + + addOutputRowForRightJoin(rightInput_, rightIndex_); + + ++rightIndex_; + if (rightIndex_ == rightInput_->size()) { + // Ran out of rows on the right side. + rightInput_ = nullptr; + return nullptr; + } + } + } + + if (noMoreRightInput_ && output_) { + output_->resize(outputSize_); + return std::move(output_); + } } else { if (noMoreInput_ || noMoreRightInput_) { if (output_) { @@ -770,9 +847,11 @@ RowVectorPtr MergeJoin::doGetOutput() { return std::move(output_); } addOutputRowForLeftJoin(input_, index_); + ++index_; + } else { + index_ = firstNonNull(input_, leftKeys_, index_ + 1); } - ++index_; if (index_ == input_->size()) { // Ran out of rows on the left side. input_ = nullptr; @@ -783,7 +862,24 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up rightInput_ with input_. while (compareResult > 0) { - rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1); + if (isRightJoin(joinType_)) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + + addOutputRowForRightJoin(rightInput_, rightIndex_); + ++rightIndex_; + } else { + rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1); + } + if (rightIndex_ == rightInput_->size()) { // Ran out of rows on the right side. rightInput_ = nullptr; @@ -862,8 +958,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; - if (leftJoinTracker_) { - const auto& filterRows = leftJoinTracker_->matchingRows(numRows); + if (joinTracker_) { + const auto& filterRows = joinTracker_->matchingRows(numRows); if (!filterRows.hasSelections()) { // No matches in the output, no need to evaluate the filter. @@ -878,9 +974,16 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; - for (auto& projection : rightProjections_) { - auto target = output->childAt(projection.outputChannel); - target->setNull(row, true); + if (!isRightJoin(joinType_)) { + for (auto& projection : rightProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); + } + } else { + for (auto& projection : leftProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); + } } } }; @@ -890,7 +993,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const bool passed = !decodedFilterResult_.isNullAt(i) && decodedFilterResult_.valueAt(i); - leftJoinTracker_->processFilterResult(i, passed, onMiss); + joinTracker_->processFilterResult(i, passed, onMiss); if (isAntiJoin(joinType_)) { if (!passed) { @@ -927,8 +1030,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // 2. leftMatch_ may not be nullopt, but may be related to a different // (subsequent) left key. So we check if the last row in the batch has the // same left row number as the last key match. - if (!leftMatch_ || !leftJoinTracker_->isCurrentLeftMatch(numRows - 1)) { - leftJoinTracker_->noMoreFilterResults(onMiss); + if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) { + joinTracker_->noMoreFilterResults(onMiss); } } else { filterRows_.resize(numRows); @@ -966,6 +1069,12 @@ void MergeJoin::evaluateFilter(const SelectivityVector& rows) { } bool MergeJoin::isFinished() { + if (isRightJoin(joinType_)) { + // If all rows on both the left and right sides match, we must also verify + // the 'noMoreInput_' on the left side to ensure that all results are + // complete. + return noMoreInput_ && noMoreRightInput_ && rightInput_ == nullptr; + } return noMoreInput_ && input_ == nullptr; } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index e5f5b71999f1..42222f83ae2e 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -215,6 +215,13 @@ class MergeJoin : public Operator { const RowVectorPtr& left, vector_size_t leftIndex); + /// Adds one row of output for a right-side row with no left-side match. + /// Copies values from the 'rightIndex' row of 'right' and fills in nulls + /// for columns that correspond to the right side. + void addOutputRowForRightJoin( + const RowVectorPtr& right, + vector_size_t rightIndex); + /// Evaluates join filter on 'filterInput_' and returns 'output' that contains /// a subset of rows on which the filter passed. Returns nullptr if no rows /// passed the filter. @@ -231,9 +238,9 @@ class MergeJoin : public Operator { /// rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); - /// As we populate the results of the left join, we track whether a given + /// As we populate the results of the join, we track whether a given /// output row is a result of a match between left and right sides or a miss. - /// We use LeftJoinTracker::addMatch and addMiss methods for that. + /// We use JoinTracker::addMatch and addMiss methods for that. /// /// The semantic of the filter is to include at least one left side row in the /// output after filters are applied. Therefore: @@ -256,8 +263,8 @@ class MergeJoin : public Operator { /// block, we keep the subset of passing rows. However, if the filter failed /// on all rows in such a block, we add one of these rows back and update /// build-side columns to null. - struct LeftJoinTracker { - LeftJoinTracker(vector_size_t numRows, memory::MemoryPool* pool) + struct JoinTracker { + JoinTracker(vector_size_t numRows, memory::MemoryPool* pool) : matchingRows_{numRows, false} { leftRowNumbers_ = AlignedBuffer::allocate(numRows, pool); rawLeftRowNumbers_ = leftRowNumbers_->asMutable(); @@ -391,7 +398,8 @@ class MergeJoin : public Operator { bool currentRowPassed_{false}; }; - std::optional leftJoinTracker_{std::nullopt}; + /// Used to record both left and right join. + std::optional joinTracker_{std::nullopt}; // Indices buffer used by the output dictionaries. All projection from the // left share `leftIndices_`, and projections in the right share diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 87bc26b15376..e3b969f46115 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -861,7 +861,7 @@ void JoinFuzzer::makeAlternativePlans( // Use OrderBy + MergeJoin if (joinNode->isInnerJoin() || joinNode->isLeftJoin() || joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() || - joinNode->isAntiJoin()) { + joinNode->isAntiJoin() || joinNode->isRightJoin()) { auto planWithSplits = makeMergeJoinPlan( joinType, probeKeys, buildKeys, probeInput, buildInput, outputColumns); plans.push_back(planWithSplits); diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 27b99f7373ef..a91e62ca7b17 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -155,34 +155,69 @@ class MergeJoinTest : public HiveConnectorTestBase { // Test LEFT join. planNodeIdGenerator = std::make_shared(); - plan = PlanBuilder(planNodeIdGenerator) - .values(left) - .mergeJoin( - {"c0"}, - {"u_c0"}, - PlanBuilder(planNodeIdGenerator) - .values(right) - .project({"c1 as u_c1", "c0 as u_c0"}) - .planNode(), - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kLeft) - .planNode(); + auto leftPlan = PlanBuilder(planNodeIdGenerator) + .values(left) + .mergeJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values(right) + .project({"c1 as u_c1", "c0 as u_c0"}) + .planNode(), + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kLeft) + .planNode(); // Use very small output batch size. assertQuery( - makeCursorParameters(plan, 16), + makeCursorParameters(leftPlan, 16), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); // Use regular output batch size. assertQuery( - makeCursorParameters(plan, 1024), + makeCursorParameters(leftPlan, 1024), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); // Use very large output batch size. assertQuery( - makeCursorParameters(plan, 10'000), + makeCursorParameters(leftPlan, 10'000), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); + + // Test RIGHT join. + planNodeIdGenerator = std::make_shared(); + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values(right) + .mergeJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values(left) + .project({"c1 as u_c1", "c0 as u_c0"}) + .planNode(), + "", + {"u_c0", "u_c1", "c1"}, + core::JoinType::kRight) + .planNode(); + + // Use very small output batch size. + assertQuery( + makeCursorParameters(rightPlan, 16), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Use regular output batch size. + assertQuery( + makeCursorParameters(rightPlan, 1024), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Use very large output batch size. + assertQuery( + makeCursorParameters(rightPlan, 10'000), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Test right join and left join with same result. + auto expectedResult = AssertQueryBuilder(leftPlan).copyResults(pool_.get()); + AssertQueryBuilder(rightPlan).assertResults(expectedResult); } }; @@ -346,7 +381,7 @@ TEST_F(MergeJoinTest, innerJoinFilter) { "SELECT t_c0, u_c0, u_c1 FROM t, u WHERE t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); } -TEST_F(MergeJoinTest, leftJoinFilter) { +TEST_F(MergeJoinTest, leftAndRightJoinFilter) { // Each row on the left side has at most one match on the right side. auto left = makeRowVector( {"t_c0", "t_c1"}, @@ -366,7 +401,7 @@ TEST_F(MergeJoinTest, leftJoinFilter) { createDuckDbTable("u", {right}); auto planNodeIdGenerator = std::make_shared(); - auto plan = [&](const std::string& filter) { + auto leftPlan = [&](const std::string& filter) { return PlanBuilder(planNodeIdGenerator) .values({left}) .mergeJoin( @@ -379,11 +414,28 @@ TEST_F(MergeJoinTest, leftJoinFilter) { .planNode(); }; + auto rightPlan = [&](const std::string& filter) { + return PlanBuilder(planNodeIdGenerator) + .values({right}) + .mergeJoin( + {"u_c0"}, + {"t_c0"}, + PlanBuilder(planNodeIdGenerator).values({left}).planNode(), + filter, + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kRight) + .planNode(); + }; + // Test with different output batch sizes. for (auto batchSize : {1, 3, 16}) { assertQuery( - makeCursorParameters(plan("(t_c1 + u_c1) % 2 = 0"), batchSize), + makeCursorParameters(leftPlan("(t_c1 + u_c1) % 2 = 0"), batchSize), "SELECT t_c0, t_c1, u_c1 FROM t LEFT JOIN u ON t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); + + assertQuery( + makeCursorParameters(rightPlan("(t_c1 + u_c1) % 2 = 0"), batchSize), + "SELECT t_c0, t_c1, u_c1 FROM u RIGHT JOIN t ON t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); } // A left-side row with multiple matches on the right side. @@ -412,10 +464,15 @@ TEST_F(MergeJoinTest, leftJoinFilter) { "t_c1 + u_c1 > 100", "t_c1 + u_c1 < 100"}) { assertQuery( - makeCursorParameters(plan(filter), batchSize), + makeCursorParameters(leftPlan(filter), batchSize), fmt::format( "SELECT t_c0, t_c1, u_c1 FROM t LEFT JOIN u ON t_c0 = u_c0 AND {}", filter)); + assertQuery( + makeCursorParameters(rightPlan(filter), batchSize), + fmt::format( + "SELECT t_c0, t_c1, u_c1 FROM u RIGHT JOIN t ON t_c0 = u_c0 AND {}", + filter)); } } } @@ -592,6 +649,52 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, rightJoin) { + auto left = makeRowVector( + {"t0"}, + {makeNullableFlatVector( + {1, 2, std::nullopt, 5, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 5, 6, 8, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Right join. + auto planNodeIdGenerator = std::make_shared(); + auto rightPlan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "t0 > 2", + {"t0", "u0"}, + core::JoinType::kRight) + .planNode(); + AssertQueryBuilder(rightPlan, duckDbQueryRunner_) + .assertResults( + "SELECT * FROM t RIGHT JOIN u ON t.t0 = u.u0 AND t.t0 > 2"); + + auto leftPlan = + PlanBuilder(planNodeIdGenerator) + .values({right}) + .mergeJoin( + {"u0"}, + {"t0"}, + PlanBuilder(planNodeIdGenerator).values({left}).planNode(), + "t0 > 2", + {"t0", "u0"}, + core::JoinType::kLeft) + .planNode(); + auto expectedResult = AssertQueryBuilder(leftPlan).copyResults(pool_.get()); + AssertQueryBuilder(rightPlan).assertResults(expectedResult); +} + TEST_F(MergeJoinTest, nullKeys) { auto left = makeRowVector( {"t0"}, {makeNullableFlatVector({1, 2, 5, std::nullopt})});