Let HashJoinBridge be able to reclaim memory in the middle state
tanjialiang committed Oct 22, 2024
1 parent 92a2644 commit 81a3bf0
Showing 7 changed files with 303 additions and 70 deletions.
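At a high level, this commit lets the join bridge spill a hash table that has already been built but is still waiting for the probe side: HashBuild now hands the bridge a spill callback together with the table, and the hash-join memory reclaimer can ask the bridge to run that callback when memory arbitration kicks in while the table is parked. Below is a minimal standalone sketch of that control flow; every type and name in it is an illustrative placeholder rather than the actual Velox API, and the callback is simplified to return freed bytes directly (in the real change it returns a SpillPartitionSet and the freed bytes are computed from the table's memory pools).

```cpp
// Minimal sketch of the new reclaim path (placeholder types, not Velox API).
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

struct HashTable {
  bool empty() const { return false; }  // stand-in for numDistinct() == 0
  void clear() {}                       // stand-in for BaseHashTable::clear(true)
};

// Packaged by the build side; spills the table's rows and reports freed bytes.
using TableSpillFn = std::function<uint64_t(const std::shared_ptr<HashTable>&)>;

class Bridge {
 public:
  // Build side: hand over the table together with the spill callback.
  void setHashTable(std::shared_ptr<HashTable> table, TableSpillFn spillFn) {
    table_ = std::move(table);
    spillFn_ = std::move(spillFn);
  }

  // Probe side flips this flag; the parked table can no longer be reclaimed.
  void startProbe() { probeStarted_ = true; }

  // Reclaimable only while a non-empty table sits idle in the bridge.
  bool canReclaim() const {
    return spillFn_ != nullptr && !probeStarted_ && table_ != nullptr &&
        !table_->empty();
  }

  uint64_t reclaim() {
    const uint64_t freed = spillFn_(table_);  // write the rows to disk
    table_->clear();                          // drop the in-memory rows
    return freed;
  }

 private:
  std::shared_ptr<HashTable> table_;
  TableSpillFn spillFn_;
  bool probeStarted_{false};
};
```

In the actual diff the callback is HashJoinTableSpillFunc, the spilled partitions are merged into the bridge's pending spill partition set, and HashJoinMemoryReclaimer only falls back to this path when it could not reclaim anything from the build operators.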
20 changes: 19 additions & 1 deletion velox/exec/HashBuild.cpp
@@ -768,8 +768,26 @@ bool HashBuild::finishHashBuild() {
RuntimeCounter(timing.wallNanos, RuntimeCounter::Unit::kNanos));

addRuntimeStats();

// Set up the spill function for spilling the hash table directly from the
// hash join bridge after table ownership has been transferred.
HashJoinTableSpillFunc tableSpillFunc;
if (canReclaim()) {
VELOX_CHECK_NOT_NULL(spiller_);
tableSpillFunc = [hashBitRange = spiller_->hashBits(),
joinNode = joinNode_,
spillConfig = spillConfig(),
spillStats =
&spillStats_](std::shared_ptr<BaseHashTable> table) {
return spillHashJoinTable(
table, hashBitRange, joinNode, spillConfig, spillStats);
};
}
joinBridge_->setHashTable(
std::move(table_), std::move(spillPartitions), joinHasNullKeys_);
std::move(table_),
std::move(spillPartitions),
joinHasNullKeys_,
tableSpillFunc);
if (canSpill()) {
stateCleared_ = true;
}
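The lambda above packages everything the bridge needs to spill after the HashBuild operator has finished: the spiller's hash-bit range, the join plan node, the spill config, and a pointer to the build operator's spill stats; the captured pointers are assumed to stay valid for as long as the bridge may invoke the callback. A small standalone illustration of this capture-and-hand-off pattern follows, with placeholder types rather than the real spillHashJoinTable() signature.

```cpp
// Sketch: the build side binds its spill parameters into a callback that
// outlives the operator (placeholder types, not the Velox signatures).
#include <functional>
#include <memory>

struct Table {};
struct SpillConfig {};
struct SpillStats {};
struct SpillPartitions {};

using TableSpillFn =
    std::function<SpillPartitions(const std::shared_ptr<Table>&)>;

// Stand-in for the real spill routine invoked by the callback.
SpillPartitions spillTable(
    const std::shared_ptr<Table>& /*table*/,
    int /*hashBitRange*/,
    const SpillConfig* /*config*/,
    SpillStats* /*stats*/) {
  return SpillPartitions{};
}

TableSpillFn makeTableSpillFn(
    int hashBitRange, const SpillConfig* config, SpillStats* stats) {
  // hashBitRange is captured by value; config and stats by pointer, so their
  // owners must outlive every use of the callback by the bridge.
  return [hashBitRange, config, stats](const std::shared_ptr<Table>& table) {
    return spillTable(table, hashBitRange, config, stats);
  };
}
```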
59 changes: 56 additions & 3 deletions velox/exec/HashJoinBridge.cpp
@@ -64,6 +64,52 @@ void HashJoinBridge::addBuilder() {
++numBuilders_;
}

bool HashJoinBridge::canReclaim() const {
return tableSpillFunc_ != nullptr && !probeStarted_.load() &&
buildResult_.has_value() && buildResult_->table != nullptr &&
buildResult_->table->numDistinct() != 0;
}

uint64_t HashJoinBridge::reclaim() {
VELOX_CHECK(buildResult_.has_value());
VELOX_CHECK_NOT_NULL(buildResult_->table);

auto computeTableReservedBytes = [](std::vector<RowContainer*> allRows) {
uint64_t totalReservedBytes{0};
for (const auto* rowContainer : allRows) {
totalReservedBytes += rowContainer->pool()->reservedBytes();
}
return totalReservedBytes;
};
const auto oldMemUsage =
computeTableReservedBytes(buildResult_->table->allRows());

auto spillPartitionSet = tableSpillFunc_(buildResult_->table);
buildResult_->table->clear(true);

const auto reclaimedBytes =
oldMemUsage - computeTableReservedBytes(buildResult_->table->allRows());

auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);
if (restoringSpillPartitionId_.has_value()) {
for (const auto& id : spillPartitionIdSet) {
VELOX_DCHECK_LT(
restoringSpillPartitionId_->partitionBitOffset(),
id.partitionBitOffset());
}
}
for (auto& partitionEntry : spillPartitionSet) {
const auto id = partitionEntry.first;
VELOX_CHECK_EQ(spillPartitionSets_.count(id), 0);
spillPartitionSets_.emplace(id, std::move(partitionEntry.second));
}
buildResult_->restoredPartitionId = restoringSpillPartitionId_;
buildResult_->spillPartitionIds = spillPartitionIdSet;
restoringSpillPartitionId_.reset();

return reclaimedBytes;
}
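The reclaimed-byte figure above is simply the drop in reserved bytes across the table's row-container pools after the spill and clear, and the newly spilled partitions are folded into 'spillPartitionSets_' and recorded on 'buildResult_' so the probe side knows it has to restore them from disk. A tiny illustration of the accounting, with a placeholder in place of RowContainer and its memory pool:

```cpp
#include <cstdint>
#include <vector>

// Placeholder for a row container's memory pool.
struct Pool {
  uint64_t reservedBytes{0};
};

// Sum of reserved bytes across the table's row-container pools; reclaim()
// reports the difference between this total before and after the spill.
uint64_t totalReservedBytes(const std::vector<const Pool*>& pools) {
  uint64_t total = 0;
  for (const auto* pool : pools) {
    total += pool->reservedBytes;
  }
  return total;
}
```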

namespace {
// Creates a spiller for spilling the row container from one of the sub-tables
// of 'table' to parallelize the table spilling. The function spills all the rows
@@ -188,7 +234,8 @@ SpillPartitionSet spillHashJoinTable(
void HashJoinBridge::setHashTable(
std::unique_ptr<BaseHashTable> table,
SpillPartitionSet spillPartitionSet,
bool hasNullKeys) {
bool hasNullKeys,
const HashJoinTableSpillFunc& tableSpillFunc) {
VELOX_CHECK_NOT_NULL(table, "setHashTable called with null table");

auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);
@@ -199,7 +246,7 @@ void HashJoinBridge::setHashTable(
VELOX_CHECK(started_);
VELOX_CHECK(!buildResult_.has_value());
VELOX_CHECK(restoringSpillShards_.empty());

tableSpillFunc_ = tableSpillFunc;
if (restoringSpillPartitionId_.has_value()) {
for (const auto& id : spillPartitionIdSet) {
VELOX_DCHECK_LT(
@@ -266,7 +313,7 @@ std::optional<HashJoinBridge::HashBuildResult> HashJoinBridge::tableOrFuture(
!buildResult_.has_value() ||
(!restoringSpillPartitionId_.has_value() &&
restoringSpillShards_.empty()));

probeStarted_ = true;
if (buildResult_.has_value()) {
return buildResult_.value();
}
@@ -286,6 +333,7 @@ bool HashJoinBridge::probeFinished() {
!restoringSpillPartitionId_.has_value() &&
restoringSpillShards_.empty());
VELOX_CHECK_GT(numBuilders_, 0);
probeStarted_ = false;

// NOTE: we are clearing the hash table as it has been fully processed and
// not needed anymore. We'll wait for the HashBuild operator to build a new
@@ -374,6 +422,11 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
}
return !hasReclaimedFromBuild;
});
auto joinBridge = joinBridge_.lock();
VELOX_CHECK_NOT_NULL(joinBridge);
if (reclaimedBytes == 0 && joinBridge->canReclaim()) {
reclaimedBytes = joinBridge->reclaim();
}
return reclaimedBytes;
}
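The reclaimer now tries the build operators first and only falls back to the table parked in the bridge when that yields nothing; it reaches the bridge through a weak_ptr so the reclaimer, which is owned by the node's memory pool, does not extend the bridge's lifetime. A standalone sketch of that fallback order with placeholder types is below (the real code asserts the bridge is still alive rather than tolerating an expired pointer).

```cpp
#include <cstdint>
#include <functional>
#include <memory>

// Placeholder for the bridge's reclaim surface.
struct BridgeLike {
  std::function<bool()> canReclaim;
  std::function<uint64_t()> reclaim;
};

// Build operators first; the parked table in the bridge is the last resort.
uint64_t reclaimJoinMemory(
    const std::function<uint64_t()>& reclaimFromBuildOperators,
    const std::weak_ptr<BridgeLike>& weakBridge) {
  uint64_t reclaimed = reclaimFromBuildOperators();
  if (reclaimed == 0) {
    if (auto bridge = weakBridge.lock(); bridge && bridge->canReclaim()) {
      reclaimed = bridge->reclaim();
    }
  }
  return reclaimed;
}
```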

70 changes: 18 additions & 52 deletions velox/exec/HashJoinBridge.h
@@ -26,48 +26,34 @@ namespace test {
class HashJoinBridgeTestHelper;
}

using HashJoinTableSpillFunc =
std::function<SpillPartitionSet(std::shared_ptr<BaseHashTable>)>;

/// Hands over a hash table from a multi-threaded build pipeline to a
/// multi-threaded probe pipeline. This is owned by shared_ptr by all the build
/// and probe Operator instances concerned. Corresponds to the Presto concept of
/// the same name.
class HashJoinBridge : public JoinBridge {
public:
struct SpillResult {
Spiller* spiller{nullptr};
const std::exception_ptr error{nullptr};

explicit SpillResult(std::exception_ptr _error) : error(_error) {}
explicit SpillResult(Spiller* _spiller) : spiller(_spiller) {}
};

void start() override;

/// Invoked by HashBuild operator ctor to add to this bridge by incrementing
/// 'numBuilders_'. The latter is used to split the spill partition data among
/// HashBuild operators to parallelize the restoring operation.
void addBuilder();

/// Invoked to spill 'table' and returns spilled partitions. This method
/// should only be invoked when the 'table' is a ready-to-use one, meaning it
/// should not be one in the middle of building. Hence it is normally invoked
/// by probe side.
SpillPartitionSet spillTable(
std::shared_ptr<BaseHashTable> table,
folly::Synchronized<common::SpillStats>* stats);
bool canReclaim() const;

/// Triggers the parallel spilling directly from the provided 'spillers'. It
/// does not do other operations other than spill. Hence it can be invoked to
/// spill partially built table, and hence invoked by build side.
std::vector<std::unique_ptr<SpillResult>> spillTableFromSpillers(
const std::vector<Spiller*>& spillers);
uint64_t reclaim();

/// Invoked by the build operator to set the built hash table.
/// 'spillPartitionSet' contains the partitions spilled while building
/// 'table'; it only applies if disk spilling is enabled.
void setHashTable(
std::unique_ptr<BaseHashTable> table,
SpillPartitionSet spillPartitionSet,
bool hasNullKeys);
bool hasNullKeys,
const HashJoinTableSpillFunc& tableSpillFunc);

/// Invoked by the probe operator to set the spilled hash table during
/// probing. The function puts the spilled table partitions into
@@ -139,31 +125,7 @@ class HashJoinBridge : public JoinBridge {
/// 'spillPartition' will be set to null in the returned SpillInput.
std::optional<SpillInput> spillInputOrFuture(ContinueFuture* future);

/// Sets the build table type.
void maybeSetTableType(const RowTypePtr& tableType);

/// Sets the spill configs.
void maybeSetSpillConfig(const common::SpillConfig* spillConfig);

/// Sets the join plan node 'this' is responsible for.
void maybeSetJoinNode(
const std::shared_ptr<const core::HashJoinNode>& joinNode);

private:
// Spills the row container from one of the sub-table from
// 'table' to parallelize the table spilling. The function
// spills all the rows from the row container and returns the spiller for the
// caller to collect the spilled partitions and stats.
std::unique_ptr<Spiller> createSpiller(
RowContainer* subTableRows,
folly::Synchronized<common::SpillStats>* stats);

// Returns the spill hash bit range for spilling the current
// 'buildResult_->table'.
HashBitRange tableSpillHashBitRange() const;

const common::SpillConfig* spillConfig() const;

uint32_t numBuilders_{0};

// The result of the build side. It is set by the last build operator when
@@ -190,10 +152,11 @@ class HashJoinBridge : public JoinBridge {
// memory and engages in recursive spilling.
SpillPartitionSet spillPartitionSets_;

// The row type used for hash table spilling.
RowTypePtr tableType_;
std::shared_ptr<const core::HashJoinNode> joinNode_;
std::optional<common::SpillConfig> spillConfig_;
// A flag indicating whether any probe operator has poked 'this' join bridge
// to get the table. It is reset for each round of table partition processing.
std::atomic_bool probeStarted_{false};

HashJoinTableSpillFunc tableSpillFunc_{nullptr};
friend test::HashJoinBridgeTestHelper;
};

@@ -204,9 +167,10 @@ bool isLeftNullAwareJoinWithFilter(

class HashJoinMemoryReclaimer final : public MemoryReclaimer {
public:
static std::unique_ptr<memory::MemoryReclaimer> create() {
static std::unique_ptr<memory::MemoryReclaimer> create(
std::shared_ptr<HashJoinBridge> joinBridge) {
return std::unique_ptr<memory::MemoryReclaimer>(
new HashJoinMemoryReclaimer());
new HashJoinMemoryReclaimer(joinBridge));
}

uint64_t reclaim(
Expand All @@ -216,7 +180,9 @@ class HashJoinMemoryReclaimer final : public MemoryReclaimer {
memory::MemoryReclaimer::Stats& stats) final;

private:
HashJoinMemoryReclaimer() : MemoryReclaimer() {}
HashJoinMemoryReclaimer(std::shared_ptr<HashJoinBridge> joinBridge)
: MemoryReclaimer(), joinBridge_(joinBridge) {}
std::weak_ptr<HashJoinBridge> joinBridge_;
};

/// Returns true if 'pool' is a hash build operator's memory pool. The check is
16 changes: 10 additions & 6 deletions velox/exec/Task.cpp
@@ -483,7 +483,9 @@ velox::memory::MemoryPool* Task::getOrAddNodePool(
return nodePools_[planNodeId];
}
childPools_.push_back(pool_->addAggregateChild(
fmt::format("node.{}", planNodeId), createNodeReclaimer(false)));
fmt::format("node.{}", planNodeId), createNodeReclaimer([&](){
return exec::ParallelMemoryReclaimer::create(queryCtx_->spillExecutor());
})));
auto* nodePool = childPools_.back().get();
nodePools_[planNodeId] = nodePool;
return nodePool;
@@ -499,22 +501,24 @@ memory::MemoryPool* Task::getOrAddJoinNodePool(
return nodePools_[nodeId];
}
childPools_.push_back(pool_->addAggregateChild(
fmt::format("node.{}", nodeId), createNodeReclaimer(true)));
fmt::format("node.{}", nodeId), createNodeReclaimer([&]() {
return HashJoinMemoryReclaimer::create(
getHashJoinBridgeLocked(splitGroupId, planNodeId));
})));
auto* nodePool = childPools_.back().get();
nodePools_[nodeId] = nodePool;
return nodePool;
}

std::unique_ptr<memory::MemoryReclaimer> Task::createNodeReclaimer(
bool isHashJoinNode) const {
std::function<std::unique_ptr<memory::MemoryReclaimer>()> reclaimerFactory)
const {
if (pool()->reclaimer() == nullptr) {
return nullptr;
}
// Sets the memory reclaimer for the parent node memory pool when the first
// child operator that has a memory reclaimer is constructed.
return isHashJoinNode
? HashJoinMemoryReclaimer::create()
: exec::ParallelMemoryReclaimer::create(queryCtx_->spillExecutor());
return reclaimerFactory();
}

std::unique_ptr<memory::MemoryReclaimer> Task::createExchangeClientReclaimer()
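createNodeReclaimer() used to branch on an isHashJoinNode flag; it now takes a factory so the hash-join call site can bind the reclaimer to its join bridge while other nodes keep the parallel reclaimer. A minimal sketch of the same refactor with placeholder types (not the exact Task or MemoryReclaimer API):

```cpp
#include <functional>
#include <memory>

struct Reclaimer {
  virtual ~Reclaimer() = default;
};

// The caller decides which reclaimer to build; this helper only keeps the
// original early-out for a parent pool that has no reclaimer of its own.
std::unique_ptr<Reclaimer> createNodeReclaimer(
    bool parentHasReclaimer,
    const std::function<std::unique_ptr<Reclaimer>()>& reclaimerFactory) {
  if (!parentHasReclaimer) {
    return nullptr;
  }
  return reclaimerFactory();
}
```

In the diff above, the join-node path passes a lambda that creates a HashJoinMemoryReclaimer bound to the bridge returned by getHashJoinBridgeLocked(), while the generic node path passes a lambda that creates a ParallelMemoryReclaimer on the query's spill executor.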
3 changes: 2 additions & 1 deletion velox/exec/Task.h
@@ -766,7 +766,8 @@ class Task : public std::enable_shared_from_this<Task> {
// customized instance for hash join plan node, otherwise creates a default
// memory reclaimer.
std::unique_ptr<memory::MemoryReclaimer> createNodeReclaimer(
bool isHashJoinNode) const;
std::function<std::unique_ptr<memory::MemoryReclaimer>()>
reclaimerFactory) const;

// Creates a memory reclaimer instance for an exchange client if the task
// memory pool has set memory reclaimer. We don't support to reclaim memory
23 changes: 16 additions & 7 deletions velox/exec/tests/HashJoinBridgeTest.cpp
@@ -171,7 +171,8 @@ TEST_P(HashJoinBridgeTest, withoutSpill) {
// Can't call any API other than addBuilder() before the join bridge has
// been started.
VELOX_ASSERT_THROW(
joinBridge->setHashTable(createFakeHashTable(), {}, false), "");
joinBridge->setHashTable(createFakeHashTable(), {}, false, nullptr),
"");
VELOX_ASSERT_THROW(joinBridge->setAntiJoinHasNullKeys(), "");
VELOX_ASSERT_THROW(joinBridge->probeFinished(), "");
VELOX_ASSERT_THROW(joinBridge->tableOrFuture(&futures[0]), "");
@@ -204,9 +205,10 @@ TEST_P(HashJoinBridgeTest, withoutSpill) {
} else {
auto table = createFakeHashTable();
rawTable = table.get();
joinBridge->setHashTable(std::move(table), {}, false);
joinBridge->setHashTable(std::move(table), {}, false, nullptr);
VELOX_ASSERT_THROW(
joinBridge->setHashTable(createFakeHashTable(), {}, false), "");
joinBridge->setHashTable(createFakeHashTable(), {}, false, nullptr),
"");
}
ASSERT_TRUE(helper.buildResult().has_value());

@@ -317,10 +319,13 @@ TEST_P(HashJoinBridgeTest, withSpill) {
if (oneIn(2)) {
spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);
joinBridge->setHashTable(
createFakeHashTable(), std::move(spillPartitionSet), false);
createFakeHashTable(),
std::move(spillPartitionSet),
false,
nullptr);
} else {
spillByProber = !spillPartitionSet.empty();
joinBridge->setHashTable(createFakeHashTable(), {}, false);
joinBridge->setHashTable(createFakeHashTable(), {}, false, nullptr);
}
hasMoreSpill = numSpilledPartitions > numRestoredPartitions;
}
@@ -446,9 +451,13 @@ TEST_P(HashJoinBridgeTest, multiThreading) {
auto spillPartitionSet =
makeFakeSpillPartitionSet(partitionBitOffset);
joinBridge->setHashTable(
createFakeHashTable(), std::move(spillPartitionSet), false);
createFakeHashTable(),
std::move(spillPartitionSet),
false,
nullptr);
} else {
joinBridge->setHashTable(createFakeHashTable(), {}, false);
joinBridge->setHashTable(
createFakeHashTable(), {}, false, nullptr);
}
}
for (auto& promise : promises) {
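The updated tests above only ever pass nullptr for the new spill callback; a test that actually exercises the middle-state reclaim could look roughly like the sketch below. It assumes a fixture helper for creating the bridge (createJoinBridge() is purely an illustrative name) and that createFakeHashTable() yields a non-empty table, so treat it as an outline rather than a drop-in test.

```cpp
// Outline only: reclaim the table while it is parked between build and probe.
TEST_P(HashJoinBridgeTest, reclaimParkedTable) {
  auto joinBridge = createJoinBridge();  // hypothetical fixture helper
  joinBridge->addBuilder();
  joinBridge->start();

  // Hand over the table together with a trivial spill callback.
  joinBridge->setHashTable(
      createFakeHashTable(),
      /*spillPartitionSet=*/{},
      /*hasNullKeys=*/false,
      [](std::shared_ptr<BaseHashTable> /*table*/) {
        return SpillPartitionSet{};  // pretend everything was spilled
      });

  // No probe has fetched the table yet, so the bridge may spill it.
  ASSERT_TRUE(joinBridge->canReclaim());
  joinBridge->reclaim();

  // Once the table has been cleared there is nothing left to reclaim.
  ASSERT_FALSE(joinBridge->canReclaim());
}
```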