Skip to content

Commit

Permalink
Add shrinkPool implementation (facebookincubator#8865)
Browse files Browse the repository at this point in the history
Summary:
Add shrinkPool implementation to memory manager and arbitrator.

Pull Request resolved: facebookincubator#8865

Reviewed By: xiaoxmeng

Differential Revision: D54274877

Pulled By: tanjialiang

fbshipit-source-id: 631d4cc04f7932bf16823d7ef5c222ce932503b5
  • Loading branch information
tanjialiang authored and facebook-github-bot committed Feb 28, 2024
1 parent 2078f23 commit e6a986c
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 38 deletions.
6 changes: 3 additions & 3 deletions velox/common/memory/MemoryArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class NoopArbitrator : public MemoryArbitrator {
// Noop arbitrator has no memory capacity limit so no operation needed for
// memory pool capacity shrink.
uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& /*unused*/,
uint64_t /*unused*/) override {
const std::vector<std::shared_ptr<MemoryPool>>& /* unused */,
uint64_t /* unused */) override {
return 0;
}

Expand Down Expand Up @@ -440,7 +440,7 @@ bool MemoryArbitrator::Stats::operator<=(const Stats& other) const {
}

ScopedMemoryArbitrationContext::ScopedMemoryArbitrationContext(
const MemoryPool& requestor)
const MemoryPool* requestor)
: savedArbitrationCtx_(arbitrationCtx),
currentArbitrationCtx_({.requestor = requestor}) {
arbitrationCtx = &currentArbitrationCtx_;
Expand Down
10 changes: 5 additions & 5 deletions velox/common/memory/MemoryArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ class MemoryArbitrator {
virtual uint64_t shrinkCapacity(MemoryPool* pool, uint64_t targetBytes) = 0;

/// Invoked by the memory manager to shrink memory capacity from a given list
/// of memory pools by reclaiming used memory. The freed memory capacity is
/// given back to the arbitrator. The function returns the actual freed memory
/// capacity in bytes.
/// of memory pools by reclaiming free and used memory. The freed memory
/// capacity is given back to the arbitrator. The function returns the actual
/// freed memory capacity in bytes.
virtual uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) = 0;
Expand Down Expand Up @@ -361,14 +361,14 @@ class MemoryReclaimer {
/// the memory reservations during memory arbitration should come from the
/// spilling memory pool.
struct MemoryArbitrationContext {
const MemoryPool& requestor;
const MemoryPool* requestor;
};

/// Object used to set/restore the memory arbitration context when a thread is
/// under memory arbitration processing.
class ScopedMemoryArbitrationContext {
public:
explicit ScopedMemoryArbitrationContext(const MemoryPool& requestor);
explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor);
~ScopedMemoryArbitrationContext();

private:
Expand Down
16 changes: 8 additions & 8 deletions velox/common/memory/tests/MemoryArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -681,29 +681,29 @@ TEST_F(MemoryReclaimerTest, arbitrationContext) {
ASSERT_FALSE(isSpillMemoryPool(leafChild2.get()));
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
ScopedMemoryArbitrationContext arbitrationContext(*leafChild1);
ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(&memoryArbitrationContext()->requestor, leafChild1.get());
ASSERT_EQ(memoryArbitrationContext()->requestor, leafChild1.get());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
ScopedMemoryArbitrationContext arbitrationContext(*leafChild2);
ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(&memoryArbitrationContext()->requestor, leafChild2.get());
ASSERT_EQ(memoryArbitrationContext()->requestor, leafChild2.get());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
std::thread nonAbitrationThread([&]() {
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
ScopedMemoryArbitrationContext arbitrationContext(*leafChild1);
ScopedMemoryArbitrationContext arbitrationContext(leafChild1.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(&memoryArbitrationContext()->requestor, leafChild1.get());
ASSERT_EQ(memoryArbitrationContext()->requestor, leafChild1.get());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
{
ScopedMemoryArbitrationContext arbitrationContext(*leafChild2);
ScopedMemoryArbitrationContext arbitrationContext(leafChild2.get());
ASSERT_TRUE(memoryArbitrationContext() != nullptr);
ASSERT_EQ(&memoryArbitrationContext()->requestor, leafChild2.get());
ASSERT_EQ(memoryArbitrationContext()->requestor, leafChild2.get());
}
ASSERT_TRUE(memoryArbitrationContext() == nullptr);
});
Expand Down
2 changes: 1 addition & 1 deletion velox/common/memory/tests/MemoryPoolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3662,7 +3662,7 @@ TEST_P(MemoryPoolTest, overuseUnderArbitration) {
ASSERT_FALSE(child->maybeReserve(2 * kMaxSize));
ASSERT_EQ(child->currentBytes(), 0);
ASSERT_EQ(child->reservedBytes(), 0);
ScopedMemoryArbitrationContext scopedMemoryArbitration(*child);
ScopedMemoryArbitrationContext scopedMemoryArbitration(child.get());
ASSERT_TRUE(underMemoryArbitration());
ASSERT_TRUE(child->maybeReserve(2 * kMaxSize));
ASSERT_EQ(child->currentBytes(), 0);
Expand Down
26 changes: 24 additions & 2 deletions velox/common/memory/tests/MockSharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,30 @@ TEST_F(MockSharedArbitrationTest, arbitrationFailsTask) {
}

TEST_F(MockSharedArbitrationTest, shrinkMemory) {
std::vector<std::shared_ptr<MemoryPool>> pools;
ASSERT_THROW(arbitrator_->shrinkCapacity(pools, 128), VeloxException);
auto task1 = addTask(64 * MB);
auto op1 = addMemoryOp(task1);
auto task2 = addTask(64 * MB);
auto op2 = addMemoryOp(task2);

op1->allocate(64 * MB);
op1->freeAll();
auto bufOp11 = op1->allocate(32 * MB);
auto bufOp12 = op1->allocate(32 * MB);
op1->free(bufOp11);
ASSERT_EQ(op1->pool()->root()->capacity(), 64 * MB);
ASSERT_EQ(op1->pool()->root()->currentBytes(), 32 * MB);

op2->allocate(64 * MB);
op2->freeAll();
auto bufOp21 = op2->allocate(32 * MB);
auto bufOp22 = op2->allocate(32 * MB);
op2->free(bufOp21);
ASSERT_EQ(op2->pool()->root()->capacity(), 64 * MB);
ASSERT_EQ(op2->pool()->root()->currentBytes(), 32 * MB);

ASSERT_EQ(manager_->shrinkPools(kMaxMemory), 128 * MB);
ASSERT_EQ(op1->capacity(), 0);
ASSERT_EQ(op2->capacity(), 0);
}

TEST_F(MockSharedArbitrationTest, singlePoolGrowWithoutArbitration) {
Expand Down
56 changes: 43 additions & 13 deletions velox/exec/SharedArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,22 @@ uint64_t SharedArbitrator::shrinkCapacity(
return freedBytes;
}

uint64_t SharedArbitrator::shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) {
ScopedArbitration scopedArbitration(this);
targetBytes = std::max(memoryPoolTransferCapacity_, targetBytes);
std::vector<Candidate> candidates = getCandidateStats(pools);
auto freedBytes = reclaimFreeMemoryFromCandidates(candidates, targetBytes);
if (freedBytes >= targetBytes) {
return freedBytes;
}
freedBytes += reclaimUsedMemoryFromCandidates(
nullptr, candidates, targetBytes - freedBytes);
incrementFreeCapacity(freedBytes);
return freedBytes;
}

std::vector<SharedArbitrator::Candidate> SharedArbitrator::getCandidateStats(
const std::vector<std::shared_ptr<MemoryPool>>& pools) {
std::vector<SharedArbitrator::Candidate> candidates;
Expand Down Expand Up @@ -437,7 +453,8 @@ uint64_t SharedArbitrator::reclaimUsedMemoryFromCandidates(
targetBytes - freedBytes, memoryPoolTransferCapacity_);
VELOX_CHECK_GT(bytesToReclaim, 0);
freedBytes += reclaim(candidate.pool, bytesToReclaim);
if ((freedBytes >= targetBytes) || requestor->aborted()) {
if ((freedBytes >= targetBytes) ||
(requestor != nullptr && requestor->aborted())) {
break;
}
}
Expand Down Expand Up @@ -580,22 +597,39 @@ std::string SharedArbitrator::toStringLocked() const {
statsLocked().toString());
}

SharedArbitrator::ScopedArbitration::ScopedArbitration(
SharedArbitrator* arbitrator)
: requestor_(nullptr),
arbitrator_(arbitrator),
startTime_(std::chrono::steady_clock::now()),
arbitrationCtx_(requestor_) {
VELOX_CHECK_NOT_NULL(arbitrator_);
arbitrator_->startArbitration("Wait for arbitration");
}

SharedArbitrator::ScopedArbitration::ScopedArbitration(
MemoryPool* requestor,
SharedArbitrator* arbitrator)
: requestor_(requestor),
arbitrator_(arbitrator),
startTime_(std::chrono::steady_clock::now()),
arbitrationCtx_(*requestor_) {
arbitrationCtx_(requestor_) {
VELOX_CHECK_NOT_NULL(arbitrator_);
arbitrator_->startArbitration(requestor);
if (arbitrator_->arbitrationStateCheckCb_ != nullptr) {
arbitrator_->arbitrationStateCheckCb_(*requestor);
requestor_->enterArbitration();
arbitrator_->startArbitration(fmt::format(
"Wait for arbitration, requestor: {}[{}]",
requestor_->name(),
requestor_->root()->name()));
if (arbitrator_->arbitrationStateCheckCb_ != nullptr &&
requestor_ != nullptr) {
arbitrator_->arbitrationStateCheckCb_(*requestor_);
}
}

SharedArbitrator::ScopedArbitration::~ScopedArbitration() {
requestor_->leaveArbitration();
if (requestor_ != nullptr) {
requestor_->leaveArbitration();
}
const auto arbitrationTimeUs =
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - startTime_)
Expand All @@ -609,18 +643,14 @@ SharedArbitrator::ScopedArbitration::~ScopedArbitration() {
arbitrator_->finishArbitration();
}

void SharedArbitrator::startArbitration(MemoryPool* requestor) {
requestor->enterArbitration();
void SharedArbitrator::startArbitration(const std::string& arbitrationContext) {
ContinueFuture waitPromise{ContinueFuture::makeEmpty()};
{
std::lock_guard<std::mutex> l(mutex_);
RECORD_METRIC_VALUE(kMetricArbitratorRequestsCount);
++numRequests_;
if (running_) {
waitPromises_.emplace_back(fmt::format(
"Wait for arbitration, requestor: {}[{}]",
requestor->name(),
requestor->root()->name()));
waitPromises_.emplace_back(arbitrationContext);
waitPromise = waitPromises_.back().getSemiFuture();
} else {
VELOX_CHECK(waitPromises_.empty());
Expand All @@ -629,7 +659,7 @@ void SharedArbitrator::startArbitration(MemoryPool* requestor) {
}

TestValue::adjust(
"facebook::velox::memory::SharedArbitrator::startArbitration", requestor);
"facebook::velox::memory::SharedArbitrator::startArbitration", nullptr);

if (waitPromise.valid()) {
uint64_t waitTimeUs{0};
Expand Down
17 changes: 11 additions & 6 deletions velox/exec/SharedArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ class SharedArbitrator : public memory::MemoryArbitrator {
uint64_t shrinkCapacity(MemoryPool* pool, uint64_t freedBytes) final;

uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& /*unused*/,
uint64_t /*unused*/) override final {
VELOX_NYI("shrinkCapacity is not supported by SharedArbitrator");
}
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) override final;

Stats stats() const final;

Expand All @@ -80,7 +78,14 @@ class SharedArbitrator : public memory::MemoryArbitrator {

class ScopedArbitration {
public:
ScopedArbitration(MemoryPool* requestor, SharedArbitrator* arbitrator);
// Used by arbitration request NOT initiated from memory pool, e.g. through
// shrinkPools() API.
explicit ScopedArbitration(SharedArbitrator* arbitrator);

// Used by arbitration request initiated from a memory pool.
explicit ScopedArbitration(
MemoryPool* requestor,
SharedArbitrator* arbitrator);

~ScopedArbitration();

Expand Down Expand Up @@ -129,7 +134,7 @@ class SharedArbitrator : public memory::MemoryArbitrator {
// Invoked to start next memory arbitration request, and it will wait for the
// serialized execution if there is a running or other waiting arbitration
// requests.
void startArbitration(MemoryPool* requestor);
void startArbitration(const std::string& arbitrationContext);

// Invoked by a finished memory arbitration request to kick off the next
// arbitration request execution if there are any ones waiting.
Expand Down

0 comments on commit e6a986c

Please sign in to comment.