Skip to content

Commit

Permalink
Add shrinkPool implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
tanjialiang committed Feb 27, 2024
1 parent 11eab1a commit 4b9d921
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 26 deletions.
4 changes: 2 additions & 2 deletions velox/common/memory/MemoryArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ class NoopArbitrator : public MemoryArbitrator {
// Noop arbitrator has no memory capacity limit so no operation needed for
// memory pool capacity shrink.
uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& /*unused*/,
uint64_t /*unused*/) override {
const std::vector<std::shared_ptr<MemoryPool>>& /* unused */,
uint64_t /* unused */) override {
return 0;
}

Expand Down
10 changes: 5 additions & 5 deletions velox/common/memory/MemoryArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ class MemoryArbitrator {
virtual uint64_t shrinkCapacity(MemoryPool* pool, uint64_t targetBytes) = 0;

/// Invoked by the memory manager to shrink memory capacity from a given list
/// of memory pools by reclaiming used memory. The freed memory capacity is
/// given back to the arbitrator. The function returns the actual freed memory
/// capacity in bytes.
/// of memory pools by reclaiming free and used memory. The freed memory
/// capacity is given back to the arbitrator. The function returns the actual
/// freed memory capacity in bytes.
virtual uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) = 0;
Expand Down Expand Up @@ -361,14 +361,14 @@ class MemoryReclaimer {
/// the memory reservations during memory arbitration should come from the
/// spilling memory pool.
struct MemoryArbitrationContext {
const MemoryPool& requestor;
const MemoryPool* requestor;
};

/// Object used to set/restore the memory arbitration context when a thread is
/// under memory arbitration processing.
class ScopedMemoryArbitrationContext {
public:
explicit ScopedMemoryArbitrationContext(const MemoryPool& requestor);
explicit ScopedMemoryArbitrationContext(const MemoryPool* requestor);
~ScopedMemoryArbitrationContext();

private:
Expand Down
26 changes: 24 additions & 2 deletions velox/common/memory/tests/MockSharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,30 @@ TEST_F(MockSharedArbitrationTest, arbitrationFailsTask) {
}

TEST_F(MockSharedArbitrationTest, shrinkMemory) {
std::vector<std::shared_ptr<MemoryPool>> pools;
ASSERT_THROW(arbitrator_->shrinkCapacity(pools, 128), VeloxException);
auto task1 = addTask(64 * MB);
auto op1 = addMemoryOp(task1);
auto task2 = addTask(64 * MB);
auto op2 = addMemoryOp(task2);

op1->allocate(64 * MB);
op1->freeAll();
auto bufOp1_1 = op1->allocate(32 * MB);
auto bufOp1_2 = op1->allocate(32 * MB);
op1->free(bufOp1_1);
ASSERT_EQ(op1->pool()->root()->capacity(), 64 * MB);
ASSERT_EQ(op1->pool()->root()->currentBytes(), 32 * MB);

op2->allocate(64 * MB);
op2->freeAll();
auto bufOp2_1 = op2->allocate(32 * MB);
auto bufOp2_2 = op2->allocate(32 * MB);
op2->free(bufOp2_1);
ASSERT_EQ(op2->pool()->root()->capacity(), 64 * MB);
ASSERT_EQ(op2->pool()->root()->currentBytes(), 32 * MB);

ASSERT_EQ(manager_->shrinkPools(kMaxMemory), 128 * MB);
ASSERT_EQ(op1->capacity(), 0);
ASSERT_EQ(op2->capacity(), 0);
}

TEST_F(MockSharedArbitrationTest, singlePoolGrowWithoutArbitration) {
Expand Down
56 changes: 44 additions & 12 deletions velox/exec/SharedArbitrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,24 @@ uint64_t SharedArbitrator::shrinkCapacity(
return freedBytes;
}

uint64_t SharedArbitrator::shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) {
ScopedArbitration scopedArbitration(this);
targetBytes = std::max(memoryPoolTransferCapacity_, targetBytes);
std::vector<Candidate> candidates;
candidates.reserve(pools.size());
candidates = getCandidateStats(pools);
auto freedBytes = reclaimFreeMemoryFromCandidates(candidates, targetBytes);
if (freedBytes >= targetBytes) {
return freedBytes;
}
freedBytes += reclaimUsedMemoryFromCandidates(
nullptr, candidates, targetBytes - freedBytes);
incrementFreeCapacity(freedBytes);
return freedBytes;
}

std::vector<SharedArbitrator::Candidate> SharedArbitrator::getCandidateStats(
const std::vector<std::shared_ptr<MemoryPool>>& pools) {
std::vector<SharedArbitrator::Candidate> candidates;
Expand Down Expand Up @@ -437,7 +455,8 @@ uint64_t SharedArbitrator::reclaimUsedMemoryFromCandidates(
targetBytes - freedBytes, memoryPoolTransferCapacity_);
VELOX_CHECK_GT(bytesToReclaim, 0);
freedBytes += reclaim(candidate.pool, bytesToReclaim);
if ((freedBytes >= targetBytes) || requestor->aborted()) {
if ((freedBytes >= targetBytes) ||
(requestor != nullptr && requestor->aborted())) {
break;
}
}
Expand Down Expand Up @@ -580,22 +599,39 @@ std::string SharedArbitrator::toStringLocked() const {
statsLocked().toString());
}

SharedArbitrator::ScopedArbitration::ScopedArbitration(
SharedArbitrator* arbitrator)
: requestor_(nullptr),
arbitrator_(arbitrator),
startTime_(std::chrono::steady_clock::now()),
arbitrationCtx_(requestor_) {
VELOX_CHECK_NOT_NULL(arbitrator_);
arbitrator_->startArbitration("Wait for arbitration, global requested");
}

SharedArbitrator::ScopedArbitration::ScopedArbitration(
MemoryPool* requestor,
SharedArbitrator* arbitrator)
: requestor_(requestor),
arbitrator_(arbitrator),
startTime_(std::chrono::steady_clock::now()),
arbitrationCtx_(*requestor_) {
arbitrationCtx_(requestor_) {
VELOX_CHECK_NOT_NULL(arbitrator_);
arbitrator_->startArbitration(requestor);
if (arbitrator_->arbitrationStateCheckCb_ != nullptr) {
requestor->enterArbitration();
arbitrator_->startArbitration(fmt::format(
"Wait for arbitration, requestor: {}[{}]",
requestor->name(),
requestor->root()->name()));
if (arbitrator_->arbitrationStateCheckCb_ != nullptr &&
requestor != nullptr) {
arbitrator_->arbitrationStateCheckCb_(*requestor);
}
}

SharedArbitrator::ScopedArbitration::~ScopedArbitration() {
requestor_->leaveArbitration();
if (requestor_ != nullptr) {
requestor_->leaveArbitration();
}
const auto arbitrationTimeUs =
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - startTime_)
Expand All @@ -609,18 +645,14 @@ SharedArbitrator::ScopedArbitration::~ScopedArbitration() {
arbitrator_->finishArbitration();
}

void SharedArbitrator::startArbitration(MemoryPool* requestor) {
requestor->enterArbitration();
void SharedArbitrator::startArbitration(const std::string& arbitrationContext) {
ContinueFuture waitPromise{ContinueFuture::makeEmpty()};
{
std::lock_guard<std::mutex> l(mutex_);
RECORD_METRIC_VALUE(kMetricArbitratorRequestsCount);
++numRequests_;
if (running_) {
waitPromises_.emplace_back(fmt::format(
"Wait for arbitration, requestor: {}[{}]",
requestor->name(),
requestor->root()->name()));
waitPromises_.emplace_back(arbitrationContext);
waitPromise = waitPromises_.back().getSemiFuture();
} else {
VELOX_CHECK(waitPromises_.empty());
Expand All @@ -629,7 +661,7 @@ void SharedArbitrator::startArbitration(MemoryPool* requestor) {
}

TestValue::adjust(
"facebook::velox::memory::SharedArbitrator::startArbitration", requestor);
"facebook::velox::memory::SharedArbitrator::startArbitration", nullptr);

if (waitPromise.valid()) {
uint64_t waitTimeUs{0};
Expand Down
13 changes: 8 additions & 5 deletions velox/exec/SharedArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,8 @@ class SharedArbitrator : public memory::MemoryArbitrator {
uint64_t shrinkCapacity(MemoryPool* pool, uint64_t freedBytes) final;

uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& /*unused*/,
uint64_t /*unused*/) override final {
VELOX_NYI("shrinkCapacity is not supported by SharedArbitrator");
}
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) override final;

Stats stats() const final;

Expand All @@ -80,6 +78,11 @@ class SharedArbitrator : public memory::MemoryArbitrator {

class ScopedArbitration {
public:
// Used by arbitration request NOT initiated from a memory pool. E.g. global
// shrinkPools() API.
ScopedArbitration(SharedArbitrator* arbitrator);

// Used by arbitration request initiated from a memory pool.
ScopedArbitration(MemoryPool* requestor, SharedArbitrator* arbitrator);

~ScopedArbitration();
Expand Down Expand Up @@ -129,7 +132,7 @@ class SharedArbitrator : public memory::MemoryArbitrator {
// Invoked to start next memory arbitration request, and it will wait for the
// serialized execution if there is a running or other waiting arbitration
// requests.
void startArbitration(MemoryPool* requestor);
void startArbitration(const std::string& arbitrationContext);

// Invoked by a finished memory arbitration request to kick off the next
// arbitration request execution if there are any ones waiting.
Expand Down

0 comments on commit 4b9d921

Please sign in to comment.