From dcaae293fcfb0eb4713dca3361800abe8104f7b2 Mon Sep 17 00:00:00 2001 From: Xiaoxuan Meng Date: Tue, 24 Sep 2024 05:39:37 -0700 Subject: [PATCH] Add arbitration participant and operation objects to support global memory arbitration optimization (#11074) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11074 Add arbitration participant object to provide all the required arbitration operations and state management on a query memory pool inside the memory arbitrator, such as an arbitration queue to serialize the execution of arbitration requests from the same query, and to serialize the reclaim, shrink, grow and abort operations issued by both arbitration requests and background memory arbitrations. Add arbitration operation object to manage the execution of a memory arbitration request. Reviewed By: tanjialiang Differential Revision: D63055730 fbshipit-source-id: 0db85eccf6c383807eb006f1fcfad8cb0b0aa596 --- velox/common/memory/ArbitrationOperation.cpp | 147 ++ velox/common/memory/ArbitrationOperation.h | 176 ++ .../common/memory/ArbitrationParticipant.cpp | 390 ++++ velox/common/memory/ArbitrationParticipant.h | 353 ++++ velox/common/memory/MemoryPool.h | 6 + .../tests/ArbitrationParticipantTest.cpp | 1790 +++++++++++++++++ 6 files changed, 2862 insertions(+) create mode 100644 velox/common/memory/ArbitrationOperation.cpp create mode 100644 velox/common/memory/ArbitrationOperation.h create mode 100644 velox/common/memory/ArbitrationParticipant.cpp create mode 100644 velox/common/memory/ArbitrationParticipant.h create mode 100644 velox/common/memory/tests/ArbitrationParticipantTest.cpp diff --git a/velox/common/memory/ArbitrationOperation.cpp b/velox/common/memory/ArbitrationOperation.cpp new file mode 100644 index 000000000000..2a6ca38981fd --- /dev/null +++ b/velox/common/memory/ArbitrationOperation.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/common/memory/ArbitrationOperation.h"
+#include <mutex>
+
+#include "velox/common/base/Exceptions.h"
+#include "velox/common/base/RuntimeMetrics.h"
+#include "velox/common/memory/Memory.h"
+#include "velox/common/testutil/TestValue.h"
+#include "velox/common/time/Timer.h"
+
+using facebook::velox::common::testutil::TestValue;
+
+namespace facebook::velox::memory {
+using namespace facebook::velox::memory;
+
+// Captures the request parameters and the creation timestamp. The operation
+// takes ownership of the scoped participant. A zero-byte request is rejected.
+ArbitrationOperation::ArbitrationOperation(
+    ScopedArbitrationParticipant&& participant,
+    uint64_t requestBytes,
+    uint64_t timeoutMs)
+    : requestBytes_(requestBytes),
+      timeoutMs_(timeoutMs),
+      createTimeMs_(getCurrentTimeMs()),
+      participant_(std::move(participant)) {
+  VELOX_CHECK_GT(requestBytes_, 0);
+}
+
+// Sanity checks on destruction: the operation must not still be running, and
+// the allocation is either nothing or at least the requested bytes.
+ArbitrationOperation::~ArbitrationOperation() {
+  VELOX_CHECK_NE(
+      state_,
+      State::kRunning,
+      "Unexpected arbitration operation state on destruction");
+  VELOX_CHECK(
+      allocatedBytes_ == 0 || allocatedBytes_ >= requestBytes_,
+      "Unexpected allocatedBytes_ {} vs requestBytes_ {}",
+      succinctBytes(allocatedBytes_),
+      succinctBytes(requestBytes_));
+}
+
+// Returns the human readable name of 'state'. A value outside of the known
+// enumerators is reported with its numeric value instead of throwing.
+std::string ArbitrationOperation::stateName(State state) {
+  switch (state) {
+    case State::kInit:
+      return "init";
+    case State::kWaiting:
+      return "waiting";
+    case State::kRunning:
+      return "running";
+    case State::kFinished:
+      return "finished";
+    default:
+      return fmt::format("unknown state: {}", static_cast<int>(state));
+  }
+}
+
+// Moves the operation to 'state' after validating the transition. The allowed
+// transitions are: kInit -> kWaiting, kInit/kWaiting -> kRunning and
+// kRunning -> kFinished. Any other target state is a programming error.
+void ArbitrationOperation::setState(State state) {
+  switch (state) {
+    case State::kWaiting:
+      VELOX_CHECK_EQ(state_, State::kInit);
+      break;
+    case State::kRunning:
+      VELOX_CHECK(this->state_ == State::kWaiting || state_ == State::kInit);
+      break;
+    case State::kFinished:
+      VELOX_CHECK_EQ(this->state_, State::kRunning);
+      break;
+    default:
+      VELOX_UNREACHABLE(
+          "Unexpected state transition from {} to {}", state_, state);
+      break;
+  }
+  state_ = state;
+}
+
+void ArbitrationOperation::start() {
+  VELOX_CHECK_EQ(state_, State::kInit);
+  // Returns once this operation is the running operation of the participant.
+  // The participant sets the state to kWaiting if it has to queue this
+  // operation behind another one.
+  participant_->startArbitration(this);
+  setState(ArbitrationOperation::State::kRunning);
+}
+
+void ArbitrationOperation::finish() {
+  setState(State::kFinished);
+  VELOX_CHECK_EQ(finishTimeMs_, 0);
+  finishTimeMs_ = getCurrentTimeMs();
+  // Hands the participant over to the next waiting operation, if any.
+  participant_->finishArbitration(this);
+}
+
+bool ArbitrationOperation::aborted() const {
+  return participant_->aborted();
+}
+
+// The execution time is measured from creation: up to the finish time for a
+// finished operation, otherwise up to now.
+size_t ArbitrationOperation::executionTimeMs() const {
+  if (state_ == State::kFinished) {
+    VELOX_CHECK_GE(finishTimeMs_, createTimeMs_);
+    return finishTimeMs_ - createTimeMs_;
+  } else {
+    const auto currentTimeMs = getCurrentTimeMs();
+    VELOX_CHECK_GE(currentTimeMs, createTimeMs_);
+    return currentTimeMs - createTimeMs_;
+  }
+}
+
+bool ArbitrationOperation::hasTimeout() const {
+  // timeoutMs() returns an unsigned value, so "no time left" is exactly zero.
+  return state_ != State::kFinished && timeoutMs() == 0;
+}
+
+// Returns the remaining time budget, saturating at zero once the execution
+// time has reached 'timeoutMs_'.
+size_t ArbitrationOperation::timeoutMs() const {
+  if (state_ == State::kFinished) {
+    return 0;
+  }
+  const auto execTimeMs = executionTimeMs();
+  if (execTimeMs >= timeoutMs_) {
+    return 0;
+  }
+  return timeoutMs_ - execTimeMs;
+}
+
+void ArbitrationOperation::setGrowTargets() {
+  // We shall only set grow targets once after start execution.
+  VELOX_CHECK_EQ(state_, State::kRunning);
+  VELOX_CHECK(
+      maxGrowBytes_ == 0 && minGrowBytes_ == 0,
+      "Arbitration operation grow targets have already been set: {}/{}",
+      succinctBytes(maxGrowBytes_),
+      succinctBytes(minGrowBytes_));
+  participant_->getGrowTargets(requestBytes_, maxGrowBytes_, minGrowBytes_);
+  VELOX_CHECK_LE(requestBytes_, maxGrowBytes_);
+}
+
+std::ostream& operator<<(std::ostream& out, ArbitrationOperation::State state) {
+  out << ArbitrationOperation::stateName(state);
+  return out;
+}
+} // namespace facebook::velox::memory
diff --git a/velox/common/memory/ArbitrationOperation.h b/velox/common/memory/ArbitrationOperation.h
new file mode 100644
index 000000000000..884d601b6027
--- /dev/null
+++ b/velox/common/memory/ArbitrationOperation.h
@@ -0,0 +1,176 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "velox/common/base/Counters.h"
+#include "velox/common/base/GTestMacros.h"
+#include "velox/common/base/StatsReporter.h"
+#include "velox/common/future/VeloxPromise.h"
+#include "velox/common/memory/ArbitrationParticipant.h"
+#include "velox/common/memory/Memory.h"
+
+namespace facebook::velox::memory {
+
+/// Manages the execution of a memory arbitration request within the arbitrator.
+class ArbitrationOperation {
+ public:
+  /// Creates an operation that requests 'requestBytes' (must be non-zero) on
+  /// behalf of 'participant' with a time budget of 'timeoutMs' measured from
+  /// the creation of this object.
+  ArbitrationOperation(
+      ScopedArbitrationParticipant&& participant,
+      uint64_t requestBytes,
+      uint64_t timeoutMs);
+
+  ~ArbitrationOperation();
+
+  enum class State {
+    kInit,
+    kWaiting,
+    kRunning,
+    kFinished,
+  };
+
+  State state() const {
+    return state_;
+  }
+
+  static std::string stateName(State state);
+
+  /// Returns the corresponding arbitration participant.
+  const ScopedArbitrationParticipant& participant() const {
+    return participant_;
+  }
+
+  /// Invoked to start arbitration execution on the arbitration participant. The
+  /// latter ensures the serialized execution of arbitration operations from the
+  /// same query with one at a time. So this method blocks until all the prior
+  /// arbitration operations finish.
+  void start();
+
+  /// Invoked to finish arbitration execution on the arbitration participant. It
+  /// also resumes the next waiting arbitration operation to execute if there is
+  /// one.
+  void finish();
+
+  /// Returns true if the corresponding arbitration participant has been
+  /// aborted.
+  bool aborted() const;
+
+  /// Invoked to set the grow targets for this arbitration operation based on
+  /// the request size.
+  ///
+  /// NOTE: this should be called once after the arbitration operation is
+  /// started.
+  void setGrowTargets();
+
+  uint64_t requestBytes() const {
+    return requestBytes_;
+  }
+
+  /// Returns the max grow bytes for this arbitration operation which could be
+  /// larger than the request bytes for exponential growth.
+  uint64_t maxGrowBytes() const {
+    return maxGrowBytes_;
+  }
+
+  /// Returns the min grow bytes for this arbitration operation to ensure the
+  /// arbitration participant has the minimum amount of memory capacity. The
+  /// arbitrator might allocate memory from the reserved memory capacity pool
+  /// for the min grow bytes.
+  uint64_t minGrowBytes() const {
+    return minGrowBytes_;
+  }
+
+  /// Returns a mutable reference to the bytes allocated by this arbitration
+  /// operation, so the caller can both read and update the value in place.
+  uint64_t& allocatedBytes() {
+    return allocatedBytes_;
+  }
+
+  /// Returns the remaining execution time for this operation before time out.
+  /// If the operation has already finished, this returns zero.
+  size_t timeoutMs() const;
+
+  /// Returns true if this operation has timed out.
+  bool hasTimeout() const;
+
+  /// Returns the execution time of this arbitration operation since creation.
+  size_t executionTimeMs() const;
+
+  /// Getters/Setters of the wait time in (local) arbitration participant wait
+  /// queue or (global) arbitrator request wait queue. Each wait time can only
+  /// be set once: the local one while the operation is in waiting state and
+  /// the global one while it is in running state.
+  void setLocalArbitrationWaitTimeUs(uint64_t waitTimeUs) {
+    VELOX_CHECK_EQ(localArbitrationWaitTimeUs_, 0);
+    VELOX_CHECK_EQ(state_, State::kWaiting);
+    localArbitrationWaitTimeUs_ = waitTimeUs;
+  }
+
+  uint64_t localArbitrationWaitTimeUs() const {
+    return localArbitrationWaitTimeUs_;
+  }
+
+  void setGlobalArbitrationWaitTimeUs(uint64_t waitTimeUs) {
+    VELOX_CHECK_EQ(globalArbitrationWaitTimeUs_, 0);
+    VELOX_CHECK_EQ(state_, State::kRunning);
+    globalArbitrationWaitTimeUs_ = waitTimeUs;
+  }
+
+  uint64_t globalArbitrationWaitTimeUs() const {
+    return globalArbitrationWaitTimeUs_;
+  }
+
+ private:
+  // Validates and applies a state transition.
+  void setState(State state);
+
+  const uint64_t requestBytes_;
+  const uint64_t timeoutMs_;
+
+  // The start time of this arbitration operation.
+  const uint64_t createTimeMs_;
+  const ScopedArbitrationParticipant participant_;
+
+  State state_{State::kInit};
+
+  // The time at which finish() was called, zero until then.
+  uint64_t finishTimeMs_{0};
+
+  uint64_t maxGrowBytes_{0};
+  uint64_t minGrowBytes_{0};
+
+  // The actual bytes allocated from arbitrator based on the request bytes and
+  // grow targets. It is either zero on failure or between 'requestBytes_' and
+  // 'maxGrowBytes_' on success.
+  uint64_t allocatedBytes_{0};
+
+  // The time that waits in local arbitration queue.
+  uint64_t localArbitrationWaitTimeUs_{0};
+
+  // The time that waits for global arbitration queue.
+  uint64_t globalArbitrationWaitTimeUs_{0};
+
+  friend class ArbitrationParticipant;
+};
+
+std::ostream& operator<<(std::ostream& out, ArbitrationOperation::State state);
+} // namespace facebook::velox::memory
+
+/// Formats an arbitration operation state by its name. 'format' is const as
+/// required by newer fmt releases.
+template <>
+struct fmt::formatter<facebook::velox::memory::ArbitrationOperation::State>
+    : formatter<std::string> {
+  auto format(
+      facebook::velox::memory::ArbitrationOperation::State state,
+      format_context& ctx) const {
+    return formatter<std::string>::format(
+        facebook::velox::memory::ArbitrationOperation::stateName(state), ctx);
+  }
+};
diff --git a/velox/common/memory/ArbitrationParticipant.cpp b/velox/common/memory/ArbitrationParticipant.cpp
new file mode 100644
index 000000000000..ceaea9dd888a
--- /dev/null
+++ b/velox/common/memory/ArbitrationParticipant.cpp
@@ -0,0 +1,390 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/common/memory/ArbitrationParticipant.h"
+#include <mutex>
+
+#include "velox/common/base/Exceptions.h"
+#include "velox/common/base/RuntimeMetrics.h"
+#include "velox/common/memory/ArbitrationOperation.h"
+#include "velox/common/memory/Memory.h"
+#include "velox/common/testutil/TestValue.h"
+#include "velox/common/time/Timer.h"
+
+using facebook::velox::common::testutil::TestValue;
+
+namespace facebook::velox::memory {
+using namespace facebook::velox::memory;
+
+// Renders all the config fields, in declaration order, as raw values.
+std::string ArbitrationParticipant::Config::toString() const {
+  return fmt::format(
+      "minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, minFreeCapacity {}, minFreeCapacityRatio {}",
+      minCapacity,
+      fastExponentialGrowthCapacityLimit,
+      slowCapacityGrowRatio,
+      minFreeCapacity,
+      minFreeCapacityRatio);
+}
+
+// Validates the config on construction: the two growth settings have to be
+// either both zero or both non-zero, and the same holds for the two shrink
+// settings. The ratios are range checked as well.
+ArbitrationParticipant::Config::Config(
+    uint64_t _minCapacity,
+    uint64_t _fastExponentialGrowthCapacityLimit,
+    double _slowCapacityGrowRatio,
+    uint64_t _minFreeCapacity,
+    double _minFreeCapacityRatio)
+    : minCapacity(_minCapacity),
+      fastExponentialGrowthCapacityLimit(_fastExponentialGrowthCapacityLimit),
+      slowCapacityGrowRatio(_slowCapacityGrowRatio),
+      minFreeCapacity(_minFreeCapacity),
+      minFreeCapacityRatio(_minFreeCapacityRatio) {
+  VELOX_CHECK_GE(slowCapacityGrowRatio, 0);
+  // Comparing the two "is zero" booleans enforces the both-or-neither rule.
+  VELOX_CHECK_EQ(
+      fastExponentialGrowthCapacityLimit == 0,
+      slowCapacityGrowRatio == 0,
+      "fastExponentialGrowthCapacityLimit {} and slowCapacityGrowRatio {} "
+      "both need to be set (non-zero) at the same time to enable growth capacity "
+      "adjustment.",
+      fastExponentialGrowthCapacityLimit,
+      slowCapacityGrowRatio);
+
+  VELOX_CHECK_GE(minFreeCapacityRatio, 0);
+  VELOX_CHECK_LE(minFreeCapacityRatio, 1);
+  VELOX_CHECK_EQ(
+      minFreeCapacity == 0,
+      minFreeCapacityRatio == 0,
+      "minFreeCapacity {} and minFreeCapacityRatio {} both "
+      "need to be set (non-zero) at the same time to enable shrink capacity "
+      "adjustment.",
+      minFreeCapacity,
+      minFreeCapacityRatio);
+}
+
+std::shared_ptr ArbitrationParticipant::create( + uint64_t id, + const std::shared_ptr& pool, + const Config* config) { + return std::shared_ptr( + new ArbitrationParticipant(id, pool, config)); +} + +ArbitrationParticipant::ArbitrationParticipant( + uint64_t id, + const std::shared_ptr& pool, + const Config* config) + : id_(id), + poolWeakPtr_(pool), + pool_(pool.get()), + config_(config), + maxCapacity_(pool_->maxCapacity()), + createTimeUs_(getCurrentTimeMicro()) { + VELOX_CHECK_LE( + config_->minCapacity, + maxCapacity_, + "The min capacity is larger than the max capacity for memory pool {}.", + pool_->name()); +} + +ArbitrationParticipant::~ArbitrationParticipant() { + VELOX_CHECK_NULL(runningOp_); + VELOX_CHECK(waitOps_.empty()); +} + +std::optional ArbitrationParticipant::lock() { + auto sharedPtr = poolWeakPtr_.lock(); + if (sharedPtr == nullptr) { + return {}; + } + return ScopedArbitrationParticipant(shared_from_this(), std::move(sharedPtr)); +} + +uint64_t ArbitrationParticipant::maxGrowCapacity() const { + const auto capacity = pool_->capacity(); + VELOX_CHECK_LE(capacity, maxCapacity_); + return maxCapacity_ - capacity; +} + +uint64_t ArbitrationParticipant::minGrowCapacity() const { + const auto capacity = pool_->capacity(); + if (capacity >= config_->minCapacity) { + return 0; + } + return config_->minCapacity - capacity; +} + +bool ArbitrationParticipant::inactivePool() const { + // Checks if a query memory pool is actively used by query execution or not. + // If not, then we don't have to respect the memory pool min limit or reserved + // capacity check. + // + // NOTE: for query system like Prestissimo, it holds a finished query + // state in minutes for query stats fetch request from the Presto + // coordinator. 
+ return pool_->reservedBytes() == 0 && pool_->peakBytes() != 0; +} + +uint64_t ArbitrationParticipant::reclaimableFreeCapacity() const { + return std::min(maxShrinkCapacity(), maxReclaimableCapacity()); +} + +uint64_t ArbitrationParticipant::maxReclaimableCapacity() const { + if (inactivePool()) { + return pool_->capacity(); + } + const uint64_t capacityBytes = pool_->capacity(); + if (capacityBytes < config_->minCapacity) { + return 0; + } + return capacityBytes - config_->minCapacity; +} + +uint64_t ArbitrationParticipant::reclaimableUsedCapacity() const { + const auto maxReclaimableBytes = maxReclaimableCapacity(); + const auto reclaimableBytes = pool_->reclaimableBytes(); + return std::min(maxReclaimableBytes, reclaimableBytes.value_or(0)); +} + +uint64_t ArbitrationParticipant::maxShrinkCapacity() const { + const uint64_t capacity = pool_->capacity(); + const uint64_t freeBytes = pool_->freeBytes(); + if (config_->minFreeCapacity != 0 && !inactivePool()) { + const uint64_t minFreeBytes = std::min( + static_cast(capacity * config_->minFreeCapacityRatio), + config_->minFreeCapacity); + if (freeBytes <= minFreeBytes) { + return 0; + } else { + return freeBytes - minFreeBytes; + } + } else { + return freeBytes; + } +} + +bool ArbitrationParticipant::checkCapacityGrowth(uint64_t requestBytes) const { + return maxGrowCapacity() >= requestBytes; +} + +void ArbitrationParticipant::getGrowTargets( + uint64_t requestBytes, + uint64_t& maxGrowBytes, + uint64_t& minGrowBytes) const { + const uint64_t capacity = pool_->capacity(); + if (config_->fastExponentialGrowthCapacityLimit == 0 && + config_->slowCapacityGrowRatio == 0) { + maxGrowBytes = requestBytes; + } else { + if (capacity * 2 <= config_->fastExponentialGrowthCapacityLimit) { + maxGrowBytes = capacity; + } else { + maxGrowBytes = capacity * config_->slowCapacityGrowRatio; + } + } + maxGrowBytes = std::max(requestBytes, maxGrowBytes); + minGrowBytes = minGrowCapacity(); + maxGrowBytes = std::max(maxGrowBytes, 
minGrowBytes); + maxGrowBytes = std::min(maxGrowCapacity(), maxGrowBytes); + + VELOX_CHECK_LE(minGrowBytes, maxGrowBytes); + VELOX_CHECK_LE(requestBytes, maxGrowBytes); +} + +void ArbitrationParticipant::startArbitration(ArbitrationOperation* op) { + ContinueFuture waitPromise{ContinueFuture::makeEmpty()}; + { + std::lock_guard l(stateLock_); + ++numRequests_; + if (runningOp_ != nullptr) { + op->setState(ArbitrationOperation::State::kWaiting); + WaitOp waitOp{ + op, + ContinuePromise{fmt::format( + "Wait for arbitration on {}", op->participant()->name())}}; + waitPromise = waitOp.waitPromise.getSemiFuture(); + waitOps_.emplace_back(std::move(waitOp)); + } else { + runningOp_ = op; + } + } + + TestValue::adjust( + "facebook::velox::memory::ArbitrationParticipant::startArbitration", + this); + + if (waitPromise.valid()) { + uint64_t waitTimeUs{0}; + { + MicrosecondTimer timer(&waitTimeUs); + waitPromise.wait(); + } + op->setLocalArbitrationWaitTimeUs(waitTimeUs); + } +} + +void ArbitrationParticipant::finishArbitration(ArbitrationOperation* op) { + ContinuePromise resumePromise{ContinuePromise::makeEmpty()}; + { + std::lock_guard l(stateLock_); + VELOX_CHECK_EQ(static_cast(op), static_cast(runningOp_)); + if (!waitOps_.empty()) { + resumePromise = std::move(waitOps_.front().waitPromise); + runningOp_ = waitOps_.front().op; + waitOps_.pop_front(); + } else { + runningOp_ = nullptr; + } + } + if (resumePromise.valid()) { + resumePromise.setValue(); + } +} + +uint64_t ArbitrationParticipant::reclaim( + uint64_t targetBytes, + uint64_t maxWaitTimeMs) noexcept { + if (targetBytes == 0) { + return 0; + } + std::lock_guard l(reclaimLock_); + TestValue::adjust( + "facebook::velox::memory::ArbitrationParticipant::reclaim", this); + uint64_t reclaimedBytes{0}; + MemoryReclaimer::Stats reclaimStats; + try { + ++numReclaims_; + pool_->reclaim(targetBytes, maxWaitTimeMs, reclaimStats); + reclaimedBytes = shrink(); + } catch (const std::exception& e) { + VELOX_MEM_LOG(ERROR) << 
"Failed to reclaim from memory pool " + << pool_->name() << ", aborting it: " << e.what(); + abortLocked(std::current_exception()); + reclaimedBytes = shrink(/*reclaimAll=*/true); + } + return reclaimedBytes; +} + +bool ArbitrationParticipant::grow( + uint64_t growBytes, + uint64_t reservationBytes) { + std::lock_guard l(stateLock_); + ++numGrows_; + const bool success = pool_->grow(growBytes, reservationBytes); + if (success) { + growBytes_ += growBytes; + } + return success; +} + +uint64_t ArbitrationParticipant::shrink(bool reclaimAll) { + std::lock_guard l(stateLock_); + ++numShrinks_; + + uint64_t reclaimedBytes{0}; + if (reclaimAll) { + reclaimedBytes = pool_->shrink(0); + } else { + const uint64_t reclaimTargetBytes = reclaimableFreeCapacity(); + if (reclaimTargetBytes > 0) { + reclaimedBytes = pool_->shrink(reclaimTargetBytes); + } + } + reclaimedBytes_ += reclaimedBytes; + return reclaimedBytes; +} + +uint64_t ArbitrationParticipant::abort( + const std::exception_ptr& error) noexcept { + std::lock_guard l(reclaimLock_); + return abortLocked(error); +} + +uint64_t ArbitrationParticipant::abortLocked( + const std::exception_ptr& error) noexcept { + TestValue::adjust( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", this); + { + std::lock_guard l(stateLock_); + if (aborted_) { + return 0; + } + aborted_ = true; + } + try { + pool_->abort(error); + } catch (const std::exception& e) { + VELOX_MEM_LOG(WARNING) << "Failed to abort memory pool " + << pool_->toString() << ", error: " << e.what(); + } + // NOTE: no matter query memory pool abort throws or not, it should have been + // marked as aborted to prevent any new memory arbitration operations. 
+ VELOX_CHECK(pool_->aborted()); + return shrink(/*reclaimAll=*/true); +} + +bool ArbitrationParticipant::waitForReclaimOrAbort( + uint64_t maxWaitTimeMs) const { + std::unique_lock l( + reclaimLock_, std::chrono::milliseconds(maxWaitTimeMs)); + return l.owns_lock(); +} + +bool ArbitrationParticipant::hasRunningOp() const { + std::lock_guard l(stateLock_); + return runningOp_ != nullptr; +} + +size_t ArbitrationParticipant::numWaitingOps() const { + std::lock_guard l(stateLock_); + return waitOps_.size(); +} + +std::string ArbitrationParticipant::Stats::toString() const { + return fmt::format( + "numRequests: {}, numReclaims: {}, numShrinks: {}, numGrows: {}, reclaimedBytes: {}, growBytes: {}, aborted: {}, duration: {}", + numRequests, + numReclaims, + numShrinks, + numGrows, + succinctBytes(reclaimedBytes), + succinctBytes(growBytes), + aborted, + succinctMicros(durationUs)); +} + +ScopedArbitrationParticipant::ScopedArbitrationParticipant( + std::shared_ptr ArbitrationParticipant, + std::shared_ptr memPool) + : ArbitrationParticipant_(std::move(ArbitrationParticipant)), + pool_(std::move(memPool)) { + VELOX_CHECK_NOT_NULL(ArbitrationParticipant_); + VELOX_CHECK_NOT_NULL(pool_); +} + +ArbitrationCandidate::ArbitrationCandidate( + ScopedArbitrationParticipant&& _participant, + bool freeCapacityOnly) + : participant(std::move(_participant)), + reclaimableUsedCapacity( + freeCapacityOnly ? 
0 : participant->reclaimableUsedCapacity()), + reclaimableFreeCapacity(participant->reclaimableFreeCapacity()) {} + +std::string ArbitrationCandidate::toString() const { + return fmt::format( + "{} RECLAIMABLE_USED_CAPACITY {} RECLAIMABLE_FREE_CAPACITY {}", + participant->name(), + succinctBytes(reclaimableUsedCapacity), + succinctBytes(reclaimableFreeCapacity)); +} +} // namespace facebook::velox::memory diff --git a/velox/common/memory/ArbitrationParticipant.h b/velox/common/memory/ArbitrationParticipant.h new file mode 100644 index 000000000000..b014a3c3327f --- /dev/null +++ b/velox/common/memory/ArbitrationParticipant.h @@ -0,0 +1,353 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/common/memory/MemoryArbitrator.h" + +#include "velox/common/base/Counters.h" +#include "velox/common/base/GTestMacros.h" +#include "velox/common/base/StatsReporter.h" +#include "velox/common/future/VeloxPromise.h" +#include "velox/common/memory/Memory.h" + +namespace facebook::velox::memory { + +class ArbitrationOperation; +class ScopedArbitrationParticipant; + +/// Manages the memory arbitration operations on a query memory pool. It also +/// tracks the arbitration stats during the query memory pool's lifecycle. +class ArbitrationParticipant + : public std::enable_shared_from_this { + public: + struct Config { + /// The minimum capacity of a query memory pool. 
+ uint64_t minCapacity; + + /// When growing a query memory pool capacity, the growth bytes will be + /// adjusted in the following way: + /// - If 2 * current capacity is less than or equal to + /// 'fastExponentialGrowthCapacityLimit', grow through fast path by at + /// least doubling the current capacity, when conditions allow (see below + /// NOTE section). + /// - If 2 * current capacity is greater than + /// 'fastExponentialGrowthCapacityLimit', grow through slow path by + /// growing capacity by at least 'slowCapacityGrowRatio' * current + /// capacity if allowed (see below NOTE section). + /// + /// NOTE: if original requested growth bytes is larger than the adjusted + /// growth bytes or adjusted growth bytes reaches max capacity limit, the + /// adjusted growth bytes will not be respected. + /// + /// NOTE: capacity growth adjust is only enabled if both + /// 'fastExponentialGrowthCapacityLimit' and 'slowCapacityGrowRatio' are + /// set, otherwise it is disabled. + uint64_t fastExponentialGrowthCapacityLimit; + double slowCapacityGrowRatio; + + /// When shrinking a memory pool capacity, the shrink bytes will be adjusted + /// in a way such that AFTER shrink, the stricter (whichever is smaller) of + /// the following conditions is met, in order to better fit the query memory + /// pool's current memory usage: + /// - Free capacity is greater or equal to capacity * + /// 'minFreeCapacityRatio' + /// - Free capacity is greater or equal to 'minFreeCapacity' + /// + /// NOTE: in the conditions when original requested shrink bytes ends up + /// with more free capacity than above 2 conditions, the adjusted shrink + /// bytes is not respected. + /// + /// NOTE: capacity shrink adjustment is enabled when both + /// 'minFreeCapacityRatio' and 'minFreeCapacity' are set. 
+ uint64_t minFreeCapacity; + double minFreeCapacityRatio; + + Config( + uint64_t _minCapacity, + uint64_t _fastExponentialGrowthCapacityLimit, + double _slowCapacityGrowRatio, + uint64_t _minFreeCapacity, + double _minFreeCapacityRatio); + + std::string toString() const; + }; + + static std::shared_ptr create( + uint64_t id, + const std::shared_ptr& pool, + const Config* config); + + ~ArbitrationParticipant(); + + /// Returns the query memory pool name of this arbitration participant. + std::string name() const { + return pool_->name(); + } + + /// Returns the id of this arbitration participant assigned by the arbitrator. + /// The id is monotonically increasing and unique across all the alive + /// arbitration participants. + uint64_t id() const { + return id_; + } + + /// Returns the max capacity of the underlying query memory pool. + uint64_t maxCapacity() const { + return maxCapacity_; + } + + /// Returns the min capacity of the underlying query memory pool. + uint64_t minCapacity() const { + return config_->minCapacity; + } + + /// Returns the duration of this arbitration participant since its creation. + uint64_t durationUs() const { + const auto now = getCurrentTimeMicro(); + VELOX_CHECK_GE(now, createTimeUs_); + return now - createTimeUs_; + } + + /// Invoked to acquire a shared reference to this arbitration participant + /// which ensures the liveness of underlying query memory pool. If the query + /// memory pool is being destroyed, then this function returns std::nullopt. + /// + // NOTE: it is not safe to directly access arbitration participant as it only + // holds a weak ptr to the query memory pool. Use 'lock()' to get a scoped + // arbitration participant for access. + std::optional lock(); + + /// Returns the corresponding query memory pool. + MemoryPool* pool() const { + return pool_; + } + + /// Returns the current capacity of the query memory pool. 
+ uint64_t capacity() const { + return pool_->capacity(); + } + + /// Gets the capacity growth targets based on 'requestBytes' and the query + /// memory pool's current capacity. 'maxGrowBytes' is set to allow fast + /// exponential growth when the query memory pool is small and switch to the + /// slow incremental growth after the query memory pool has grown big. + /// 'minGrowBytes' is set to ensure the query memory pool has the minimum + /// capacity and certain headroom free capacity after shrink. Both targets are + /// set to a coarser granularity to reduce the number of unnecessary future + /// memory arbitration requests. The parameters used to set the targets are + /// defined in 'config_'. + void getGrowTargets( + uint64_t requestBytes, + uint64_t& maxGrowBytes, + uint64_t& minGrowBytes) const; + + /// Returns the unused free memory capacity that can be reclaimed from the + /// query memory pool by shrink. + uint64_t reclaimableFreeCapacity() const; + + /// Returns the used memory capacity that can be reclaimed from the query + /// memory pool through disk spilling. + uint64_t reclaimableUsedCapacity() const; + + /// Checks if the query memory pool can grow 'requestBytes' from its current + /// capacity under the max capacity limit. + bool checkCapacityGrowth(uint64_t requestBytes) const; + + /// Invoked to grow the query memory pool capacity by 'growBytes' and commit + /// used reservation by 'reservationBytes'. The function throws if the growth + /// fails. + bool grow(uint64_t growBytes, uint64_t reservationBytes); + + /// Invoked to release the unused memory capacity by reducing its capacity. If + /// 'reclaimAll' is true, the function releases all the unused memory capacity + /// from the query memory pool without regarding to the minimum free capacity + /// restriction. + uint64_t shrink(bool reclaimAll = false); + + // Invoked to reclaim used memory from this memory pool with specified + // 'targetBytes'. 
The function returns the actually freed capacity. + uint64_t reclaim(uint64_t targetBytes, uint64_t maxWaitTimeMs) noexcept; + + /// Invoked to abort the query memory pool and returns the reclaimed bytes + /// after abort. + uint64_t abort(const std::exception_ptr& error) noexcept; + + /// Returns true if the query memory pool has been aborted. + bool aborted() const { + std::lock_guard l(stateLock_); + return aborted_; + } + + /// Invoked to wait for the pending memory reclaim or abort operation to + /// complete within a 'maxWaitTimeMs' time window. The function returns false + /// if the wait has timed out. + bool waitForReclaimOrAbort(uint64_t maxWaitTimeMs) const; + + /// Invoked to start arbitration operation 'op'. The operation needs to wait + /// for the prior arbitration operations to finish first before executing to + /// ensure the serialized execution of arbitration operations from the same + /// query memory pool. + void startArbitration(ArbitrationOperation* op); + + /// Invoked by a finished arbitration operation 'op' to kick off the next + /// waiting operation to start execution if there is one. + void finishArbitration(ArbitrationOperation* op); + + /// Returns true if there is a running arbitration operation on this + /// participant. + bool hasRunningOp() const; + + /// Returns the number of waiting arbitration operations on this participant. 
+  size_t numWaitingOps() const;
+
+  /// Point-in-time copy of this participant's arbitration counters, as
+  /// produced by stats().
+  struct Stats {
+    uint64_t durationUs{0};
+    uint32_t numRequests{0};
+    uint32_t numReclaims{0};
+    uint32_t numShrinks{0};
+    uint32_t numGrows{0};
+    uint64_t reclaimedBytes{0};
+    uint64_t growBytes{0};
+    bool aborted{false};
+
+    std::string toString() const;
+  };
+
+  /// Returns a snapshot of the participant's counters and abort state. The
+  /// fields are read one by one, so the snapshot is not taken atomically as a
+  /// whole.
+  Stats stats() const {
+    Stats stats;
+    stats.durationUs = durationUs();
+    // 'aborted_' is a plain bool which aborted() only reads under
+    // 'stateLock_'. Go through the accessor so this read is synchronized the
+    // same way instead of touching the flag unlocked. As the accessor acquires
+    // 'stateLock_' (a non-recursive mutex), stats() must not be called with
+    // that lock held.
+    stats.aborted = aborted();
+    stats.numRequests = numRequests_;
+    stats.numGrows = numGrows_;
+    stats.numShrinks = numShrinks_;
+    stats.numReclaims = numReclaims_;
+    stats.reclaimedBytes = reclaimedBytes_;
+    stats.growBytes = growBytes_;
+    return stats;
+  }
+
+ private:
+  ArbitrationParticipant(
+      uint64_t id,
+      const std::shared_ptr& pool,
+      const Config* config);
+
+  // Indicates if the query memory pool is actively used by a query execution or
+  // not.
+  // NOTE(review): the name suggests 'true' means the pool is NOT actively used;
+  // confirm against the definition.
+  bool inactivePool() const;
+
+  // Returns the max capacity to reclaim from the query memory pool assuming all
+  // the query memory is reclaimable.
+  uint64_t maxReclaimableCapacity() const;
+
+  // Returns the max capacity to shrink from the query memory pool. It ensures
+  // the memory pool having headroom free capacity after shrink as specified by
+  // 'minFreeCapacityRatio' and 'minFreeCapacity' in 'config_'. This helps to
+  // reduce the number of unnecessary memory arbitration requests.
+  uint64_t maxShrinkCapacity() const;
+
+  // Returns the max capacity the query memory pool may grow by, as specified by
+  // 'fastExponentialGrowthCapacityLimit' and 'slowCapacityGrowRatio' in
+  // 'config_'.
+  uint64_t maxGrowCapacity() const;
+
+  // Returns the min capacity to grow the query memory pool to have the minimal
+  // capacity as specified by 'minCapacity' in 'config_'.
+  uint64_t minGrowCapacity() const;
+
+  // Aborts the query memory pool and returns the reclaimed bytes after abort.
+  uint64_t abortLocked(const std::exception_ptr& error) noexcept;
+
+  // State fixed at construction; all of these members are const.
+  const uint64_t id_;
+  // Weak handle to the query memory pool. ScopedArbitrationParticipant is the
+  // wrapper that carries a shared reference to the pool.
+  const std::weak_ptr poolWeakPtr_;
+  // Raw pointer to the query memory pool, e.g. used by capacity().
+  MemoryPool* const pool_;
+  const Config* const config_;
+  const uint64_t maxCapacity_;
+  const size_t createTimeUs_;
+
+  // Held by aborted() while reading 'aborted_'.
+  // NOTE(review): presumably also protects 'runningOp_' and 'waitOps_';
+  // confirm against ArbitrationParticipant.cpp.
+  mutable std::mutex stateLock_;
+  bool aborted_{false};
+
+  // Points to the current running arbitration operation on this participant.
+  ArbitrationOperation* runningOp_{nullptr};
+
+  // An arbitration operation queued on this participant together with the
+  // promise used to resume it.
+  struct WaitOp {
+    ArbitrationOperation* op;
+    ContinuePromise waitPromise;
+  };
+  // The resume promises of the arbitration operations on this participant
+  // waiting for serial execution.
+  std::deque waitOps_;
+
+  // Counters reported through stats().
+  tsan_atomic numRequests_{0};
+  tsan_atomic numReclaims_{0};
+  tsan_atomic numShrinks_{0};
+  tsan_atomic numGrows_{0};
+  tsan_atomic reclaimedBytes_{0};
+  tsan_atomic growBytes_{0};
+
+  // NOTE(review): a timed mutex, presumably so reclaim/abort callers can give
+  // up after a wait limit (see waitForReclaimOrAbort()); confirm against the
+  // definitions.
+  mutable std::timed_mutex reclaimLock_;
+
+  friend class ScopedArbitrationParticipant;
+};
+
+/// The wrapper of the arbitration participant which holds a shared reference to
+/// the query memory pool to ensure its liveness during memory arbitration
+/// execution.
+class ScopedArbitrationParticipant {
+ public:
+  // NOTE(review): the first parameter is named like the class it wraps and the
+  // matching member below is 'ArbitrationParticipant_'; both depart from the
+  // lowerCamel naming used elsewhere in this header. Renaming them also
+  // requires touching the out-of-line constructor definition.
+  ScopedArbitrationParticipant(
+      std::shared_ptr ArbitrationParticipant,
+      std::shared_ptr pool);
+
+  /// Member access on the wrapped participant.
+  ArbitrationParticipant* operator->() const {
+    return ArbitrationParticipant_.get();
+  }
+
+  /// Returns a reference to the wrapped participant.
+  ArbitrationParticipant& operator*() const {
+    return *ArbitrationParticipant_;
+  }
+
+  /// Call-operator form of operator*(); returns the same reference.
+  ArbitrationParticipant& operator()() const {
+    return *ArbitrationParticipant_;
+  }
+
+  /// Returns the raw pointer to the wrapped participant.
+  ArbitrationParticipant* get() const {
+    return ArbitrationParticipant_.get();
+  }
+
+ private:
+  // The wrapped participant.
+  std::shared_ptr ArbitrationParticipant_;
+  // Shared reference to the query memory pool (see the class comment).
+  std::shared_ptr pool_;
+};
+
+/// The candidate participant stats used by arbitrator to make arbitration
+/// decisions.
+struct ArbitrationCandidate {
+  /// The participant this candidate refers to.
+  ScopedArbitrationParticipant participant;
+  /// NOTE(review): the two capacities below are presumably filled in by the
+  /// constructor from the participant's reclaimableUsedCapacity() and
+  /// reclaimableFreeCapacity(); confirm against the definition. They are
+  /// signed here while those participant methods return uint64_t.
+  int64_t reclaimableUsedCapacity{0};
+  int64_t reclaimableFreeCapacity{0};
+
+  /// If 'freeCapacityOnly' is true, the candidate is only used to reclaim free
+  /// capacity so only collects the free capacity stats.
+  ArbitrationCandidate(
+      ScopedArbitrationParticipant&& _participant,
+      bool freeCapacityOnly);
+
+  std::string toString() const;
+};
+} // namespace facebook::velox::memory
diff --git a/velox/common/memory/MemoryPool.h b/velox/common/memory/MemoryPool.h index 1f3b33ed3a90..985c21d2f91e 100644 --- a/velox/common/memory/MemoryPool.h +++ b/velox/common/memory/MemoryPool.h @@ -39,6 +39,10 @@ namespace facebook::velox::exec { class ParallelMemoryReclaimer; } +namespace facebook::velox::memory { +class TestArbitrator; +} + namespace facebook::velox::memory { #define VELOX_MEM_POOL_CAP_EXCEEDED(errorMessage) \ _VELOX_THROW( \ @@ -558,7 +562,9 @@ class MemoryPool : public std::enable_shared_from_this { friend class velox::exec::ParallelMemoryReclaimer; friend class MemoryManager; friend class MemoryArbitrator; + friend class velox::memory::TestArbitrator; friend class ScopedMemoryPoolArbitrationCtx; + friend class ArbitrationParticipant; VELOX_FRIEND_TEST(MemoryPoolTest, shrinkAndGrowAPIs); VELOX_FRIEND_TEST(MemoryPoolTest, grow); diff --git a/velox/common/memory/tests/ArbitrationParticipantTest.cpp b/velox/common/memory/tests/ArbitrationParticipantTest.cpp new file mode 100644 index 000000000000..d546d7f6647d --- /dev/null +++ b/velox/common/memory/tests/ArbitrationParticipantTest.cpp @@ -0,0 +1,1790 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include + +#include "folly/experimental/EventCount.h" +#include "folly/futures/Barrier.h" + +#include "gmock/gmock-matchers.h" +#include "velox/common/base/SuccinctPrinter.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/memory/ArbitrationOperation.h" +#include "velox/common/memory/ArbitrationParticipant.h" +#include "velox/common/memory/MallocAllocator.h" +#include "velox/common/memory/Memory.h" +#include "velox/common/memory/MemoryArbitrator.h" +#include "velox/common/memory/SharedArbitrator.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/exec/OperatorUtils.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" + +DECLARE_bool(velox_memory_leak_check_enabled); +DECLARE_bool(velox_suppress_memory_capacity_exceeding_error_message); + +using namespace ::testing; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::memory { +static const std::string arbitratorKind("TEST"); + +class TestArbitrator : public MemoryArbitrator { + public: + explicit TestArbitrator(const Config& config) + : MemoryArbitrator( + {.kind = config.kind, + .capacity = config.capacity, + .extraConfigs = config.extraConfigs}) {} + + void addPool(const std::shared_ptr& /*unused*/) override {} + + void removePool(MemoryPool* /*unused*/) override {} + + bool growCapacity(MemoryPool* memoryPool, uint64_t requestBytes) override 
{ + VELOX_CHECK_LE( + memoryPool->capacity() + requestBytes, memoryPool->maxCapacity()); + memoryPool->grow(requestBytes, requestBytes); + return true; + } + + uint64_t shrinkCapacity(uint64_t /*unused*/, bool /*unused*/, bool /*unused*/) + override { + VELOX_NYI(); + } + + uint64_t shrinkCapacity(MemoryPool* /*unused*/, uint64_t /*unused*/) + override { + VELOX_NYI(); + } + + Stats stats() const override { + VELOX_NYI(); + } + + std::string toString() const override { + VELOX_NYI(); + } + + std::string kind() const override { + return arbitratorKind; + } +}; + +namespace { +constexpr int64_t KB = 1024L; +constexpr int64_t MB = 1024L * KB; + +constexpr uint64_t kMemoryCapacity = 512 * MB; +constexpr uint64_t kMemoryPoolReservedCapacity = 64 * MB; +constexpr uint64_t kMemoryPoolMinFreeCapacity = 32 * MB; +constexpr double kMemoryPoolMinFreeCapacityRatio = 0.25; +constexpr uint64_t kFastExponentialGrowthCapacityLimit = 256 * MB; +constexpr double kSlowCapacityGrowRatio = 0.25; + +class MemoryReclaimer; + +using ReclaimInjectionCallback = + std::function; +using ArbitrationInjectionCallback = std::function; + +struct Allocation { + void* buffer{nullptr}; + size_t size{0}; +}; + +class MockTask : public std::enable_shared_from_this { + public: + MockTask(MemoryManager* manager, uint64_t capacity) + : root_(manager->addRootPool( + fmt::format("TaskPool-{}", taskId_++), + capacity)), + pool_(root_->addLeafChild("MockOperator")) {} + + ~MockTask() { + free(); + } + + class RootMemoryReclaimer : public memory::MemoryReclaimer { + public: + RootMemoryReclaimer(const std::shared_ptr& task) : task_(task) {} + + static std::unique_ptr create( + const std::shared_ptr& task) { + return std::make_unique(task); + } + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + auto task = task_.lock(); + if (task == nullptr) { + return false; + } + return memory::MemoryReclaimer::reclaimableBytes(pool, reclaimableBytes); + } + + uint64_t 
reclaim( + MemoryPool* pool, + uint64_t targetBytes, + uint64_t maxWaitMs, + Stats& stats) override { + auto task = task_.lock(); + if (task == nullptr) { + return 0; + } + return memory::MemoryReclaimer::reclaim( + pool, targetBytes, maxWaitMs, stats); + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) override { + auto task = task_.lock(); + if (task == nullptr) { + return; + } + memory::MemoryReclaimer::abort(pool, error); + } + + private: + std::weak_ptr task_; + }; + + class LeafMemoryReclaimer : public memory::MemoryReclaimer { + public: + LeafMemoryReclaimer( + std::shared_ptr task, + bool reclaimable, + ReclaimInjectionCallback reclaimInjectCb = nullptr, + ArbitrationInjectionCallback arbitrationInjectCb = nullptr) + : task_(task), + reclaimable_(reclaimable), + reclaimInjectCb_(std::move(reclaimInjectCb)), + arbitrationInjectCb_(std::move(arbitrationInjectCb)) {} + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + if (!reclaimable_) { + return false; + } + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + return task->reclaimableBytes(pool, reclaimableBytes); + } + + uint64_t reclaim( + MemoryPool* pool, + uint64_t targetBytes, + uint64_t /*unused*/, + Stats& stats) override { + if (!reclaimable_) { + return 0; + } + if (reclaimInjectCb_ != nullptr) { + reclaimInjectCb_(pool, targetBytes); + } + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + const auto reclaimBytes = task->reclaim(pool, targetBytes); + stats.reclaimedBytes += reclaimBytes; + return reclaimBytes; + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) override { + std::shared_ptr task = task_.lock(); + VELOX_CHECK_NOT_NULL(task); + task->abort(pool, error); + } + + private: + std::weak_ptr task_; + const bool reclaimable_; + const ReclaimInjectionCallback reclaimInjectCb_; + const ArbitrationInjectionCallback arbitrationInjectCb_; + + std::exception_ptr abortError_; + }; 
+ + void setMemoryReclaimers( + bool reclaimable, + ReclaimInjectionCallback reclaimInjectCb, + ArbitrationInjectionCallback arbitrationInjectCb) { + root_->setReclaimer(RootMemoryReclaimer::create(shared_from_this())); + pool_->setReclaimer(std::make_unique( + shared_from_this(), + reclaimable, + std::move(reclaimInjectCb), + std::move(arbitrationInjectCb))); + } + + const std::shared_ptr& pool() const { + return root_; + } + + std::exception_ptr abortError() const { + return abortError_; + } + + uint64_t capacity() const { + return root_->capacity(); + } + + void* allocate(uint64_t bytes) { + VELOX_CHECK_EQ(bytes % pool_->alignment(), 0); + + void* buffer = pool_->allocate(bytes); + std::lock_guard l(mu_); + totalBytes_ += bytes; + allocations_.emplace(buffer, bytes); + VELOX_CHECK_EQ(allocations_.count(buffer), 1); + return buffer; + } + + void free(void* buffer) { + size_t size{0}; + { + std::lock_guard l(mu_); + VELOX_CHECK_EQ(allocations_.count(buffer), 1); + size = allocations_[buffer]; + totalBytes_ -= size; + allocations_.erase(buffer); + } + pool_->free(buffer, size); + } + + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const { + std::lock_guard l(mu_); + VELOX_CHECK_EQ(pool.name(), pool_->name()); + reclaimableBytes = totalBytes_; + return true; + } + + uint64_t reclaim(MemoryPool* pool, uint64_t targetBytes) { + VELOX_CHECK_GT(targetBytes, 0); + ++numReclaims_; + reclaimTargetBytes_.push_back(targetBytes); + uint64_t bytesReclaimed{0}; + std::vector allocationsToFree; + { + std::lock_guard l(mu_); + VELOX_CHECK_NOT_NULL(pool_); + VELOX_CHECK_EQ(pool->name(), pool_->name()); + + auto allocIt = allocations_.begin(); + while (allocIt != allocations_.end() && + ((targetBytes != 0) && (bytesReclaimed < targetBytes))) { + allocationsToFree.push_back({allocIt->first, allocIt->second}); + bytesReclaimed += allocIt->second; + allocIt = allocations_.erase(allocIt); + } + totalBytes_ -= bytesReclaimed; + } + for (const auto& 
allocation : allocationsToFree) { + pool_->free(allocation.buffer, allocation.size); + } + return bytesReclaimed; + } + + void abort(MemoryPool* pool, const std::exception_ptr& error) { + ++numAborts_; + abortError_ = error; + free(); + } + + struct Stats { + uint64_t numReclaims; + uint64_t numAborts; + std::vector reclaimTargetBytes; + }; + + Stats stats() const { + Stats stats; + stats.numReclaims = numReclaims_; + stats.reclaimTargetBytes = reclaimTargetBytes_; + stats.numAborts = numAborts_; + return stats; + } + + private: + void free() { + std::unordered_map allocationsToFree; + { + std::lock_guard l(mu_); + for (auto entry : allocations_) { + totalBytes_ -= entry.second; + } + VELOX_CHECK_EQ(totalBytes_, 0); + allocationsToFree.swap(allocations_); + } + for (auto entry : allocationsToFree) { + pool_->free(entry.first, entry.second); + } + } + + inline static std::atomic_int taskId_{0}; + + const std::shared_ptr root_; + const std::shared_ptr pool_; + + mutable std::mutex mu_; + uint64_t totalBytes_{0}; + + std::unordered_map allocations_; + std::atomic_uint64_t numReclaims_{0}; + std::atomic_uint64_t numAborts_{0}; + std::vector reclaimTargetBytes_; + std::exception_ptr abortError_{nullptr}; +}; + +class ArbitrationParticipantTest : public testing::Test { + protected: + static void SetUpTestCase() { + SharedArbitrator::registerFactory(); + FLAGS_velox_memory_leak_check_enabled = true; + TestValue::enable(); + MemoryArbitrator::Factory factory = + [](const MemoryArbitrator::Config& config) { + return std::make_unique(config); + }; + MemoryArbitrator::registerFactory(arbitratorKind, factory); + } + + void SetUp() override { + setupMemory(); + } + + void TearDown() override {} + + void setupMemory(int64_t memoryCapacity = kMemoryCapacity) { + MemoryManagerOptions options; + options.allocatorCapacity = memoryCapacity; + options.arbitratorReservedCapacity = 0; + options.arbitratorKind = arbitratorKind; + options.checkUsageLeak = true; + manager_ = 
std::make_unique(options); + } + + std::shared_ptr createTask( + int64_t capacity = 0, + bool reclaimable = true, + ReclaimInjectionCallback reclaimInjectCb = nullptr, + ArbitrationInjectionCallback arbitrationInjectCb = nullptr) { + if (capacity == 0) { + capacity = manager_->capacity(); + } + auto task = std::make_shared(manager_.get(), capacity); + task->setMemoryReclaimers( + reclaimable, reclaimInjectCb, arbitrationInjectCb); + return task; + } + + std::unique_ptr manager_; + std::unique_ptr executor_ = + std::make_unique(4); +}; + +static ArbitrationParticipant::Config arbitrationConfig( + uint64_t minCapacity = kMemoryPoolReservedCapacity, + uint64_t fastExponentialGrowthCapacityLimit = + kFastExponentialGrowthCapacityLimit, + double slowCapacityGrowRatio = kSlowCapacityGrowRatio, + uint64_t minFreeCapacity = kMemoryPoolMinFreeCapacity, + double minFreeCapacityRatio = kMemoryPoolMinFreeCapacityRatio) { + return ArbitrationParticipant::Config{ + minCapacity, + fastExponentialGrowthCapacityLimit, + slowCapacityGrowRatio, + minFreeCapacity, + minFreeCapacityRatio}; +} + +TEST_F(ArbitrationParticipantTest, config) { + struct { + uint64_t minCapacity; + uint64_t fastExponentialGrowthCapacityLimit; + double slowCapacityGrowRatio; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + bool expectedError; + std::string expectedToString; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, fastExponentialGrowthCapacityLimit: {}, slowCapacityGrowRatio: {}, minFreeCapacity: {}, minFreeCapacityRatio: {}, expectedError: {}, expectedToString: {}", + succinctBytes(minCapacity), + succinctBytes(fastExponentialGrowthCapacityLimit), + slowCapacityGrowRatio, + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + expectedError, + expectedToString); + } + } testSettings[] = { + {1, + 1, + 0.1, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 0.1, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, 
+ {1, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, + 0, + 0, + 0, + 0, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 0, minFreeCapacityRatio 0"}, + {1, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 1, + 0.1, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 0.1, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {0, + 0, + 0, + 0, + 0, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 0, minFreeCapacityRatio 0"}, + {0, + 0, + 0, + 1, + 0.1, + false, + "minCapacity 0, fastExponentialGrowthCapacityLimit 0, slowCapacityGrowRatio 0, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, 0, 0.1, 1, 0.1, true, ""}, + {1, 1, 0.1, 0, 0.1, true, ""}, + {1, 1, 0.1, 1, 0, true, ""}, + {1, + 1, + 2, + 1, + 0.1, + false, + "minCapacity 1, fastExponentialGrowthCapacityLimit 1, slowCapacityGrowRatio 2, minFreeCapacity 1, minFreeCapacityRatio 0.1"}, + {1, 1, -1, 1, 0.1, true, ""}, + {1, 1, 0.1, 1, 2, true, ""}, + {1, 1, 0.1, 1, -1, true, ""}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + if (testData.expectedError) { + VELOX_ASSERT_THROW( + ArbitrationParticipant::Config( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, + testData.slowCapacityGrowRatio, + testData.minFreeCapacity, + testData.minFreeCapacityRatio), + ""); + continue; + } + const auto config = ArbitrationParticipant::Config( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, 
+ testData.slowCapacityGrowRatio, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + ASSERT_EQ(testData.minCapacity, config.minCapacity); + ASSERT_EQ( + testData.fastExponentialGrowthCapacityLimit, + config.fastExponentialGrowthCapacityLimit); + ASSERT_EQ(testData.slowCapacityGrowRatio, config.slowCapacityGrowRatio); + ASSERT_EQ(testData.minFreeCapacity, config.minFreeCapacity); + ASSERT_EQ(testData.minFreeCapacityRatio, config.minFreeCapacityRatio); + ASSERT_EQ(config.toString(), testData.expectedToString); + } +} + +TEST_F(ArbitrationParticipantTest, constructor) { + auto task = createTask(); + const auto config = arbitrationConfig(); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + ASSERT_EQ(participant->id(), 10); + ASSERT_EQ(participant->name(), task->pool()->name()); + ASSERT_EQ(participant->pool(), task->pool().get()); + ASSERT_EQ(participant->maxCapacity(), kMemoryCapacity); + ASSERT_EQ(participant->minCapacity(), kMemoryPoolReservedCapacity); + ASSERT_EQ(participant->capacity(), 0); + ASSERT_FALSE(participant->hasRunningOp()); + ASSERT_EQ(participant->numWaitingOps(), 0); + ASSERT_THAT( + participant->stats().toString(), + ::testing::StartsWith( + "numRequests: 0, numReclaims: 0, numShrinks: 0, numGrows: 0, reclaimedBytes: 0B, growBytes: 0B, aborted: false")); + + { + auto scopedParticipant = participant->lock().value(); + ASSERT_EQ(scopedParticipant->id(), 10); + ASSERT_EQ(scopedParticipant->name(), task->pool()->name()); + ASSERT_EQ(scopedParticipant->pool(), task->pool().get()); + ASSERT_EQ(scopedParticipant->maxCapacity(), kMemoryCapacity); + ASSERT_EQ(scopedParticipant->minCapacity(), kMemoryPoolReservedCapacity); + ASSERT_EQ(scopedParticipant->capacity(), 0); + + task.reset(); + + ASSERT_EQ(scopedParticipant->capacity(), 0); + ASSERT_EQ(participant->capacity(), 0); + } + + ASSERT_FALSE(participant->lock().has_value()); +} + +TEST_F(ArbitrationParticipantTest, getGrowTargets) { + struct { + uint64_t 
minCapacity; + uint64_t fastExponentialGrowthCapacityLimit; + double slowCapacityGrowRatio; + uint64_t capacity; + uint64_t requestBytes; + uint64_t expectedMaxGrowTarget; + uint64_t expectedMinGrowTarget; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, fastExponentialGrowthCapacityLimit {}, slowCapacityGrowRatio {}, capacity {}, requestBytes {}, expectedMaxGrowTarget {}, expectedMinGrowTarget {}", + succinctBytes(minCapacity), + succinctBytes(fastExponentialGrowthCapacityLimit), + slowCapacityGrowRatio, + succinctBytes(capacity), + succinctBytes(requestBytes), + succinctBytes(expectedMaxGrowTarget), + succinctBytes(expectedMinGrowTarget)); + } + } testSettings[] = { + // Without exponential growth. + {0, 0, 0.0, 0, 1 << 20, 1 << 20, 0}, + {0, 0, 0.0, 32 << 20, 1 << 20, 1 << 20, 0}, + {0, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 0}, + // Fast growth. + {0, 16 << 20, 1.0, 0, 1 << 20, 1 << 20, 0}, + {0, 16 << 20, 1.0, 1 << 20, 1 << 20, 1 << 20, 0}, + {0, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {0, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {0, 16 << 20, 1.0, 8 << 20, 1 << 20, 8 << 20, 0}, + {0, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {0, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {0, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + // Slow growth. + {0, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {0, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 0}, + {0, 16 << 20, 100.0, 24 << 20, 1 << 20, kMemoryCapacity - (24 << 20), 0}, + {0, 16 << 20, 0.1, 24 << 20, 1 << 20, uint64_t((24 << 20) * 0.1), 0}, + // With min capacity. + // Without exponential growth. + {4 << 20, 0, 0.0, 0, 1 << 20, 4 << 20, 4 << 20}, + {4 << 20, 0, 0.0, 32 << 20, 1 << 20, 1 << 20, 0}, + {4 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 0}, + {64 << 20, 0, 0.0, 32 << 20, 1 << 20, 32 << 20, 32 << 20}, + {64 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 32 << 20}, + {48 << 20, 0, 0.0, 32 << 20, 32 << 20, 32 << 20, 16 << 20}, + // Fast growth. 
+ {1 << 20, 16 << 20, 1.0, 0, 1 << 20, 1 << 20, 1 << 20}, + {1 << 20, 16 << 20, 1.0, 0, 2 << 20, 2 << 20, 1 << 20}, + {4 << 20, 16 << 20, 1.0, 0, 1 << 20, 4 << 20, 4 << 20}, + {1 << 20, 16 << 20, 1.0, 1 << 20, 1 << 20, 1 << 20, 0}, + {1 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {2 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 0}, + {4 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 2 << 20, 2 << 20}, + {8 << 20, 16 << 20, 1.0, 2 << 20, 1 << 20, 6 << 20, 6 << 20}, + {3 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {4 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 0}, + {5 << 20, 16 << 20, 1.0, 4 << 20, 1 << 20, 4 << 20, 1 << 20}, + {1 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {12 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 0}, + {13 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 1 << 20}, + {24 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 12 << 20, 12 << 20}, + {25 << 20, 16 << 20, 1.0, 12 << 20, 1 << 20, 13 << 20, 13 << 20}, + {1 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {16 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 0}, + {17 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 1 << 20}, + {32 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 16 << 20, 16 << 20}, + {64 << 20, 16 << 20, 1.0, 16 << 20, 1 << 20, 48 << 20, 48 << 20}, + {1 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + {12 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 0}, + {13 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 1 << 20}, + {23 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 11 << 20}, + {35 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 23 << 20, 23 << 20}, + {48 << 20, 16 << 20, 1.0, 12 << 20, 23 << 20, 36 << 20, 36 << 20}, + // Slow growth. 
+ {1 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {24 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 0}, + {25 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 1 << 20}, + {48 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 24 << 20}, + {47 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 24 << 20, 23 << 20}, + {64 << 20, 16 << 20, 1.0, 24 << 20, 1 << 20, 40 << 20, 40 << 20}, + {1 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 0}, + {36 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 12 << 20}, + {72 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 48 << 20, 48 << 20}, + {96 << 20, 16 << 20, 2.0, 24 << 20, 1 << 20, 72 << 20, 72 << 20}, + {1 << 20, + 16 << 20, + 100.0, + 24 << 20, + 1 << 20, + kMemoryCapacity - (24 << 20), + 0}, + {36 << 20, + 16 << 20, + 100.0, + 24 << 20, + 1 << 20, + kMemoryCapacity - (24 << 20), + 12 << 20}, + {1 << 20, + 16 << 20, + 0.1, + 24 << 20, + 1 << 20, + uint64_t((24 << 20) * 0.1), + 0}, + {24 << 20, + 16 << 20, + 0.1, + 24 << 20, + 1 << 20, + uint64_t((24 << 20) * 0.1), + 0}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + testData.fastExponentialGrowthCapacityLimit, + testData.slowCapacityGrowRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + scopedParticipant->shrink(/*reclaimFromAll=*/true); + ASSERT_EQ(scopedParticipant->capacity(), 0); + void* buffer = task->allocate(testData.capacity); + SCOPE_EXIT { + task->free(buffer); + }; + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + uint64_t maxGrowBytes{0}; + uint64_t minGrowBytes{0}; + scopedParticipant->getGrowTargets( + testData.requestBytes, maxGrowBytes, minGrowBytes); + ASSERT_EQ(maxGrowBytes, testData.expectedMaxGrowTarget); + ASSERT_EQ(minGrowBytes, testData.expectedMinGrowTarget); + + // Test operation 
corresponding API. + ArbitrationOperation op( + std::move(scopedParticipant), testData.requestBytes, 1 << 30); + op.start(); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + ASSERT_EQ(op.requestBytes(), testData.requestBytes); + op.setGrowTargets(); + ASSERT_EQ(op.requestBytes(), testData.requestBytes); + ASSERT_EQ(op.maxGrowBytes(), testData.expectedMaxGrowTarget); + ASSERT_EQ(op.minGrowBytes(), testData.expectedMinGrowTarget); + // Can't set grow targets twice. + VELOX_ASSERT_THROW( + op.setGrowTargets(), + "Arbitration operation grow targets have already been set"); + op.finish(); + } +} + +TEST_F(ArbitrationParticipantTest, reclaimableFreeCapacityAndShrink) { + struct { + uint64_t minCapacity; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + uint64_t capacity; + uint64_t usedBytes; + uint64_t peakBytes; + uint64_t expectedFreeCapacity; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, minFreeCapacity {}, minFreeCapacityRatio {}, capacity {}, usedBytes {}, peakBytes {}, expectedFreeCapacity {}", + succinctBytes(minCapacity), + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(peakBytes), + succinctBytes(expectedFreeCapacity)); + } + } testSettings[] = { + {128 << 20, 0, 0.0, 128 << 20, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 0, 32 << 20, 128 << 20}, + {128 << 20, 0, 0.0, 128 << 20, 32 << 20, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 128 << 20, 0, 0}, + {128 << 20, 0, 0.0, 256 << 20, 256 << 20, 0, 0}, + {128 << 20, 0, 0.0, 256 << 20, 200 << 20, 0, 56 << 20}, + {128 << 20, 0, 0.0, 256 << 20, 32 << 20, 0, 128 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 0, 0, 96 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 0, 0, 96 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 0, 0, 224 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 0, 64 << 20, 256 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 64 << 20, 0, 32 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 96 << 20, 0, 
0}, + {0, 32 << 20, 0.25, 128 << 20, 72 << 20, 0, 24 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 64 << 20, 0, 32 << 20}, + {0, 64 << 20, 0.25, 128 << 20, 96 << 20, 0, 0}, + {0, 64 << 20, 0.25, 128 << 20, 72 << 20, 0, 24 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 64 << 20, 0, 160 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 96 << 20, 0, 128 << 20}, + {0, 32 << 20, 0.25, 256 << 20, 224 << 20, 0, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 0}, + {64 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 64 << 20}, + {96 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 32 << 20}, + {64 << 20, 64 << 20, 0.25, 128 << 20, 0, 0, 64 << 20}, + {64 << 20, 32 << 20, 0.5, 128 << 20, 0, 0, 64 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 32 << 20, 128 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 0}, + {64 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 32 << 20}, + {96 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 64 << 20, 32 << 20}, + {96 << 20, 32 << 20, 0.25, 256 << 20, 64 << 20, 64 << 20, 160 << 20}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + for (bool reclaimAll : {false, true}) { + SCOPED_TRACE(fmt::format("reclaimAll {}", reclaimAll)); + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + 0, + 0.0, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + ASSERT_EQ(scopedParticipant->stats().numShrinks, 0); + if (testData.peakBytes > 0) { + void* buffer = task->allocate(testData.peakBytes); + task->free(buffer); + ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); + } + + scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + void* buffer{nullptr}; + if (testData.usedBytes > 0) { + buffer 
= task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_EQ( + scopedParticipant->reclaimableFreeCapacity(), + testData.expectedFreeCapacity); + + const uint64_t prevReclaimedBytes = + scopedParticipant->stats().reclaimedBytes; + if (reclaimAll) { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.capacity - testData.usedBytes); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.capacity - testData.usedBytes); + } else { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.expectedFreeCapacity); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.expectedFreeCapacity); + } + ASSERT_EQ(scopedParticipant->stats().numShrinks, 2); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 0); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_GE(scopedParticipant->stats().durationUs, 0); + ASSERT_FALSE(scopedParticipant->stats().aborted); + + if (buffer != nullptr) { + task->free(buffer); + } + } + } +} + +TEST_F(ArbitrationParticipantTest, reclaimableUsedCapacityAndReclaim) { + struct { + uint64_t minCapacity; + uint64_t minFreeCapacity; + double minFreeCapacityRatio; + uint64_t capacity; + uint64_t usedBytes; + uint64_t peakBytes; + uint64_t expectedReclaimableUsedBytes; + uint64_t expectedActualReclaimedBytes; + uint64_t expectedUsedBytes; + + std::string debugString() const { + return fmt::format( + "minCapacity {}, minFreeCapacity {}, minFreeCapacityRatio {}, capacity {}, usedBytes {}, peakBytes {}, expectedReclaimableUsedBytes {}, expectedActualReclaimedBytes {}, expectedUsedBytes {}", + succinctBytes(minCapacity), + succinctBytes(minFreeCapacity), + minFreeCapacityRatio, + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(peakBytes), + succinctBytes(expectedReclaimableUsedBytes), + 
succinctBytes(expectedActualReclaimedBytes), + succinctBytes(expectedUsedBytes)); + } + } testSettings[] = { + {128 << 20, 0, 0.0, 128 << 20, 0, 0, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 0, 32 << 20, 0, 0, 0}, + {128 << 20, 0, 0.0, 128 << 20, 32 << 20, 0, 0, 0, 32 << 20}, + {64 << 20, 0, 0.0, 128 << 20, 96 << 20, 0, 64 << 20, 64 << 20, 32 << 20}, + {64 << 20, 0, 0.0, 128 << 20, 128 << 20, 0, 64 << 20, 64 << 20, 64 << 20}, + {0, 32 << 20, 0.25, 128 << 20, 0, 0, 0, 0}, + {0, 64 << 20, 0.25, 128 << 20, 0, 0, 0, 0}, + {0, 32 << 20, 0.25, 256 << 20, 0, 0, 0, 0}, + {0, 32 << 20, 0.25, 256 << 20, 0, 64 << 20, 0, 0}, + {0, 32 << 20, 0.25, 128 << 20, 96 << 20, 0, 96 << 20, 128 << 20, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 0, 0, 0, 0, 0}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 64 << 20, 0, 0, 0, 64 << 20}, + {128 << 20, 32 << 20, 0.25, 128 << 20, 128 << 20, 0, 0, 0, 128 << 20}, + {64 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 0, + 64 << 20, + 128 << 20, + 0}, + {128 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 0, + 0, + 64 << 20}, + {64 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 64 << 20, + 128 << 20, + 0}, + {96 << 20, + 32 << 20, + 0.25, + 128 << 20, + 64 << 20, + 64 << 20, + 32 << 20, + 32 << 20, + 32 << 20}, + {32 << 20, + 32 << 20, + 0.5, + 256 << 20, + 256 << 20, + 0, + 224 << 20, + 192 << 20, + 32 << 20}, + {32 << 20, + 64 << 20, + 0.125, + 256 << 20, + 256 << 20, + 0, + 224 << 20, + 192 << 20, + 32 << 20}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig( + testData.minCapacity, + 0, + 0.0, + testData.minFreeCapacity, + testData.minFreeCapacityRatio); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + if (testData.peakBytes > 0) { + void* buffer = task->allocate(testData.peakBytes); + 
task->free(buffer); + ASSERT_EQ(scopedParticipant->pool()->peakBytes(), testData.peakBytes); + } + + scopedParticipant->shrink(/*reclaimFromAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + for (int i = 0; i < testData.usedBytes / MB; ++i) { + task->allocate(MB); + } + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_EQ( + scopedParticipant->reclaimableUsedCapacity(), + testData.expectedReclaimableUsedBytes); + + const auto targetBytes = scopedParticipant->reclaimableUsedCapacity(); + const uint64_t prevReclaimedBytes = + scopedParticipant->stats().reclaimedBytes; + ASSERT_EQ( + scopedParticipant->reclaim(targetBytes, 1'000'000), + testData.expectedActualReclaimedBytes); + ASSERT_EQ( + scopedParticipant->pool()->usedBytes(), testData.expectedUsedBytes); + + if (targetBytes != 0) { + ASSERT_EQ(scopedParticipant->stats().numShrinks, 2); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 1); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_FALSE(scopedParticipant->stats().aborted); + } else { + ASSERT_EQ(scopedParticipant->stats().numShrinks, 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 0); + ASSERT_EQ(scopedParticipant->stats().numGrows, 1); + ASSERT_FALSE(scopedParticipant->stats().aborted); + } + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevReclaimedBytes + testData.expectedActualReclaimedBytes); + } +} + +TEST_F(ArbitrationParticipantTest, checkCapacityGrowth) { + struct { + uint64_t maxCapacity; + uint64_t capacity; + uint64_t requestBytes; + bool expectedGrowth; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, capacity {}, requestBytes {}, expectedGrowth {}", + succinctBytes(maxCapacity), + succinctBytes(capacity), + succinctBytes(requestBytes), + expectedGrowth); + } + 
} testSettings[] = { + {128 << 20, 32 << 20, 1 << 20, true}, + {128 << 20, 128 << 20, 1 << 20, false}, + {128 << 20, 64 << 20, 64 << 20, true}, + {128 << 20, 128 << 20, 0, true}}; + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + task->allocate(testData.capacity); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + ASSERT_EQ( + scopedParticipant->checkCapacityGrowth(testData.requestBytes), + testData.expectedGrowth); + } +} + +TEST_F(ArbitrationParticipantTest, grow) { + struct { + uint64_t maxCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t growthBytes; + uint64_t reservationBytes; + bool expectedFailure; + uint64_t expectedReservationBytes; + uint64_t expectedCapacityAfterGrowth; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, capacity {}, usedBytes {}, growthBytes {}, reservationBytes {}, expectedFailure {}, expectedReservationBytes {}, expectedCapacityAfterGrowth {}", + succinctBytes(maxCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(growthBytes), + succinctBytes(reservationBytes), + expectedFailure, + succinctBytes(expectedReservationBytes), + succinctBytes(expectedCapacityAfterGrowth)); + } + } testSettings[] = { + {256 << 20, 128 << 20, 0, 1 << 20, 0, false, 0, 129 << 20}, + {256 << 20, 128 << 20, 0, 256 << 20, 0, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 0 << 20, 192 << 20, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 256 << 20, true, 0, 128 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 32 << 20, false, 32 << 20, 160 << 20}, + {256 << 20, 128 << 20, 0, 32 << 20, 16 << 20, false, 16 << 20, 160 << 20}, + {256 << 20, 128 << 20, 0, 0, 16 << 20, false, 16 << 20, 128 << 20}, + {256 << 
20, 128 << 20, 0, 0, 128 << 20, false, 128 << 20, 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 0, + 16 << 20, + false, + 112 << 20, + 128 << 20}, + {256 << 20, 128 << 20, 96 << 20, 0, 64 << 20, true, 96 << 20, 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 8 << 20, + 64 << 20, + true, + 96 << 20, + 128 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 128 << 20, + 64 << 20, + false, + 160 << 20, + 256 << 20}, + {256 << 20, + 128 << 20, + 96 << 20, + 256 << 20, + 64 << 20, + true, + 96 << 20, + 128 << 20}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + const uint64_t prevGrowBytes = scopedParticipant->stats().growBytes; + ASSERT_EQ( + !testData.expectedFailure, + scopedParticipant->grow( + testData.growthBytes, testData.reservationBytes)); + ASSERT_EQ( + testData.expectedReservationBytes, + scopedParticipant->pool()->reservedBytes()); + ASSERT_EQ( + scopedParticipant->capacity(), testData.expectedCapacityAfterGrowth); + if (!testData.expectedFailure && testData.reservationBytes > 0) { + static_cast(scopedParticipant->pool()) + ->testingSetReservation( + testData.expectedReservationBytes - testData.reservationBytes); + } + if (testData.expectedFailure) { + ASSERT_EQ(scopedParticipant->stats().growBytes, prevGrowBytes); + } else { + ASSERT_EQ( + scopedParticipant->stats().growBytes, + prevGrowBytes + 
testData.growthBytes); + } + } +} + +TEST_F(ArbitrationParticipantTest, shrink) { + struct { + uint64_t maxCapacity; + uint64_t minCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t expectedFreeCapacity; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, minCapacity {}, capacity {}, usedBytes {}, expectedFreeCapacity {}", + succinctBytes(maxCapacity), + succinctBytes(minCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(expectedFreeCapacity)); + } + } testSettings[] = { + {256 << 20, 128 << 20, 0, 0, 0}, + {256 << 20, 128 << 20, 64 << 20, 0, 0}, + {256 << 20, 128 << 20, 64 << 20, 32 << 20, 0}, + {256 << 20, 128 << 20, 64 << 20, 64 << 20, 0}, + {256 << 20, 128 << 20, 128 << 20, 64 << 20, 0}, + {256 << 20, 128 << 20, 192 << 20, 64 << 20, 64 << 20}, + {256 << 20, 128 << 20, 256 << 20, 128 << 20, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 0, 128 << 20}, + {256 << 20, 128 << 20, 192 << 20, 0, 64 << 20}, + {256 << 20, 128 << 20, 128 << 20, 0, 0}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + for (bool reclaimAll : {false, true}) { + SCOPED_TRACE(fmt::format("reclaimAll {}", reclaimAll)); + + auto task = createTask(testData.maxCapacity); + const auto config = + arbitrationConfig(testData.minCapacity, 0, 0.0, 0, 0.0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + const uint64_t prevFreedBytes = scopedParticipant->stats().reclaimedBytes; + const uint32_t prevNumShrunks = 
scopedParticipant->stats().numShrinks; + if (reclaimAll) { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.capacity - testData.usedBytes); + ASSERT_EQ( + prevFreedBytes + testData.capacity - testData.usedBytes, + scopedParticipant->stats().reclaimedBytes); + } else { + ASSERT_EQ( + scopedParticipant->shrink(reclaimAll), + testData.expectedFreeCapacity); + ASSERT_EQ( + prevFreedBytes + testData.expectedFreeCapacity, + scopedParticipant->stats().reclaimedBytes); + } + ASSERT_EQ(prevNumShrunks + 1, scopedParticipant->stats().numShrinks); + } + } +} + +TEST_F(ArbitrationParticipantTest, abort) { + struct { + uint64_t maxCapacity; + uint64_t minCapacity; + uint64_t capacity; + uint64_t usedBytes; + uint64_t expectedReclaimCapacity; + + std::string debugString() const { + return fmt::format( + "maxCapacity {}, minCapacity {}, capacity {}, usedBytes {}, expectedReclaimCapacity {}", + succinctBytes(maxCapacity), + succinctBytes(minCapacity), + succinctBytes(capacity), + succinctBytes(usedBytes), + succinctBytes(expectedReclaimCapacity)); + } + } testSettings[] = { + {256 << 20, 128 << 20, 0, 0, 0}, + {256 << 20, 128 << 20, 128 << 20, 0, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 0, 256 << 20}, + {256 << 20, 128 << 20, 64 << 20, 0, 64 << 20}, + {256 << 20, 128 << 20, 128 << 20, 64 << 20, 128 << 20}, + {256 << 20, 128 << 20, 128 << 20, 128 << 20, 128 << 20}, + {256 << 20, 128 << 20, 256 << 20, 128 << 20, 256 << 20}, + {256 << 20, 128 << 20, 256 << 20, 256 << 20, 256 << 20}, + {256 << 20, 128 << 20, 64 << 20, 32 << 20, 64 << 20}, + {256 << 20, 128 << 20, 64 << 20, 64 << 20, 64 << 20}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + auto task = createTask(testData.maxCapacity); + const auto config = arbitrationConfig(testData.minCapacity, 0, 0.0, 0, 0.0); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + 
scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(testData.capacity, 0); + ASSERT_EQ(scopedParticipant->capacity(), testData.capacity); + + if (testData.usedBytes > 0) { + task->allocate(testData.usedBytes); + } + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), testData.usedBytes); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), testData.usedBytes); + + ASSERT_FALSE(scopedParticipant->stats().aborted); + ASSERT_FALSE(scopedParticipant->aborted()); + const uint64_t prevFreedBytes = scopedParticipant->stats().reclaimedBytes; + const uint32_t prevNumShrunks = scopedParticipant->stats().numShrinks; + const uint32_t prevNumReclaims = scopedParticipant->stats().numReclaims; + const std::string abortReason = "test abort"; + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ( + scopedParticipant->abort(std::current_exception()), + testData.expectedReclaimCapacity); + } + ASSERT_TRUE(task->pool()->aborted()); + ASSERT_TRUE(scopedParticipant->stats().aborted); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ( + scopedParticipant->stats().reclaimedBytes, + prevFreedBytes + testData.expectedReclaimCapacity); + ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims); + + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); + } + ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 1); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ(scopedParticipant->capacity(), 0); + + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + ASSERT_EQ(scopedParticipant->stats().numReclaims, prevNumReclaims + 1); + ASSERT_EQ(scopedParticipant->stats().numShrinks, prevNumShrunks + 2); + } +} + +DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, reclaimLock) { + 
auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + const uint64_t allocatedBytes = 32 * MB; + for (int i = 0; i < 32; ++i) { + task->allocate(MB); + } + auto scopedParticipant = participant->lock().value(); + + std::atomic_bool reclaim1WaitFlag{false}; + folly::EventCount reclaim1Wait; + std::atomic_bool reclaim1ResumeFlag{false}; + folly::EventCount reclaim1Resume; + std::atomic_bool reclaim2WaitFlag{false}; + folly::EventCount reclaim2Wait; + std::atomic_bool reclaim2ResumeFlag{false}; + folly::EventCount reclaim2Resume; + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + if (!reclaim1WaitFlag.exchange(true)) { + reclaim1Wait.notifyAll(); + reclaim1Resume.await([&]() { return reclaim1ResumeFlag.load(); }); + return; + } + if (!reclaim2WaitFlag.exchange(true)) { + reclaim2Wait.notifyAll(); + reclaim1Resume.await([&]() { return reclaim1ResumeFlag.load(); }); + return; + } + }))); + + std::atomic_bool abortWaitFlag{false}; + folly::EventCount abortWait; + std::atomic_bool abortResumeFlag{false}; + folly::EventCount abortResume; + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + if (!abortWaitFlag.exchange(true)) { + abortWait.notifyAll(); + abortResume.await([&]() { return abortResumeFlag.load(); }); + return; + } + }))); + + std::atomic_bool reclaim1CompletedFlag{false}; + folly::EventCount reclaim1CompletedWait; + std::thread reclaimThread1([&]() { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + reclaim1CompletedFlag = true; + reclaim1CompletedWait.notifyAll(); + }); + reclaim1Wait.await([&]() { return reclaim1WaitFlag.load(); }); + + std::atomic_bool abortCompletedFlag{false}; + folly::EventCount abortCompletedWait; + std::thread 
abortThread([&]() { + const std::string abortReason = "test abort"; + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 32 * MB); + } + abortCompletedFlag = true; + abortCompletedWait.notifyAll(); + }); + + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + ASSERT_FALSE(reclaim1CompletedFlag); + ASSERT_FALSE(abortWaitFlag); + + reclaim1ResumeFlag = true; + reclaim1Resume.notifyAll(); + reclaim1CompletedWait.await([&]() { return reclaim1CompletedFlag.load(); }); + reclaimThread1.join(); + + abortWait.await([&]() { return abortWaitFlag.load(); }); + + std::atomic_bool reclaim2CompletedFlag{false}; + folly::EventCount reclaim2CompletedWait; + std::thread reclaimThread2([&]() { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), 0); + reclaim2CompletedFlag = true; + reclaim2CompletedWait.notifyAll(); + }); + + std::this_thread::sleep_for(std::chrono::seconds(1)); // NOLINT + ASSERT_FALSE(abortCompletedFlag); + ASSERT_FALSE(reclaim2WaitFlag); + + abortResumeFlag = true; + abortResume.notifyAll(); + abortCompletedWait.await([&]() { return abortCompletedFlag.load(); }); + abortThread.join(); + + reclaim2ResumeFlag = true; + reclaim2Resume.notifyAll(); + reclaim2CompletedWait.await([&]() { return reclaim2CompletedFlag.load(); }); + reclaimThread2.join(); + + ASSERT_TRUE(task->pool()->aborted()); + ASSERT_TRUE(task->abortError() != nullptr); + ASSERT_TRUE(scopedParticipant->aborted()); + ASSERT_EQ(scopedParticipant->capacity(), 0); + ASSERT_EQ(scopedParticipant->pool()->usedBytes(), 0); + ASSERT_EQ(scopedParticipant->stats().numReclaims, 2); + ASSERT_EQ(scopedParticipant->stats().numShrinks, 3); + ASSERT_EQ(scopedParticipant->stats().reclaimedBytes, 32 << 20); +} + +DEBUG_ONLY_TEST_F(ArbitrationParticipantTest, waitForReclaimOrAbort) { + struct { + uint64_t waitTimeUs; + bool pendingReclaim; + uint64_t reclaimWaitMs{0}; + bool expectedTimeout; + + std::string 
debugString() const { + return fmt::format( + "waitTime {}, pendingReclaim {}, reclaimWait {}, expectedTimeout {}", + succinctMicros(waitTimeUs), + pendingReclaim, + succinctMillis(reclaimWaitMs), + expectedTimeout); + } + } testSettings[] = { + {0, true, 1'000, true}, + {0, false, 1'000, true}, + {1'000'000, true, 1'000, false}, + {1'000'000, true, 1'000, false}}; + + for (const auto& testData : testSettings) { + SCOPED_TRACE(testData.debugString()); + + std::atomic_bool reclaimWaitFlag{false}; + folly::EventCount reclaimWait; + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::reclaim", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + reclaimWaitFlag = true; + reclaimWait.notifyAll(); + std::this_thread::sleep_for( + std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT + }))); + + SCOPED_TESTVALUE_SET( + "facebook::velox::memory::ArbitrationParticipant::abortLocked", + std::function( + ([&](ArbitrationParticipant* /*unused*/) { + reclaimWaitFlag = true; + reclaimWait.notifyAll(); + std::this_thread::sleep_for( + std::chrono::milliseconds(testData.reclaimWaitMs)); // NOLINT + }))); + + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(); + auto participant = + ArbitrationParticipant::create(10, task->pool(), &config); + task->allocate(MB); + auto scopedParticipant = participant->lock().value(); + + std::thread reclaimThread([&]() { + if (testData.pendingReclaim) { + ASSERT_EQ(scopedParticipant->reclaim(MB, 1'000'000), MB); + } else { + const std::string abortReason = "test abort"; + try { + VELOX_FAIL(abortReason); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), MB); + } + } + }); + reclaimWait.await([&]() { return reclaimWaitFlag.load(); }); + ASSERT_EQ( + scopedParticipant->waitForReclaimOrAbort(testData.waitTimeUs), + !testData.expectedTimeout); + reclaimThread.join(); + } +} + +TEST_F(ArbitrationParticipantTest, capacityCheck) 
{ + auto task = createTask(256 << 20); + const auto config = arbitrationConfig(512 << 20); + VELOX_ASSERT_THROW( + ArbitrationParticipant::create(0, task->pool(), &config), + "The min capacity is larger than the max capacity for memory pool"); +} + +TEST_F(ArbitrationParticipantTest, arbitrationCandidate) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + + auto scopedParticipant = participant->lock().value(); + scopedParticipant->shrink(/*reclaimAll=*/true); + scopedParticipant->grow(32 << 20, 0); + ASSERT_EQ(scopedParticipant->capacity(), 32 << 20); + task->allocate(MB); + ASSERT_EQ(scopedParticipant->pool()->reservedBytes(), MB); + + ArbitrationCandidate candidateWithFreeCapacityOnly( + participant->lock().value(), /*freeCapacityOnly=*/true); + ASSERT_EQ( + candidateWithFreeCapacityOnly.participant->name(), + scopedParticipant->name()); + ASSERT_EQ(candidateWithFreeCapacityOnly.reclaimableUsedCapacity, 0); + ASSERT_EQ(candidateWithFreeCapacityOnly.reclaimableFreeCapacity, 31 << 20); + ASSERT_EQ( + candidateWithFreeCapacityOnly.toString(), + "TaskPool-0 RECLAIMABLE_USED_CAPACITY 0B RECLAIMABLE_FREE_CAPACITY 31.00MB"); + + ArbitrationCandidate candidate( + participant->lock().value(), /*freeCapacityOnly=*/false); + ASSERT_EQ(candidate.participant->name(), scopedParticipant->name()); + ASSERT_EQ(candidate.reclaimableUsedCapacity, MB); + ASSERT_EQ(candidate.reclaimableFreeCapacity, 31 << 20); + ASSERT_EQ( + candidate.toString(), + "TaskPool-0 RECLAIMABLE_USED_CAPACITY 1.00MB RECLAIMABLE_FREE_CAPACITY 31.00MB"); +} + +TEST_F(ArbitrationParticipantTest, arbitrationOperation) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + const int participantId{10}; + auto participant = + ArbitrationParticipant::create(participantId, task->pool(), &config); + auto scopedParticipant = 
participant->lock().value(); + const int requestBytes = 1 << 20; + const int opTimeoutMs = 1'000'000; + ArbitrationOperation op( + participant->lock().value(), requestBytes, opTimeoutMs); + VELOX_ASSERT_THROW( + ArbitrationOperation(participant->lock().value(), 0, opTimeoutMs), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_FALSE(op.aborted()); + ASSERT_FALSE(op.hasTimeout()); + ASSERT_EQ(op.allocatedBytes(), 0); + ASSERT_LE(op.timeoutMs(), opTimeoutMs); + + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); // NOLINT + ASSERT_GE(op.executionTimeMs(), 1'000); + ASSERT_LE(op.timeoutMs(), opTimeoutMs - 1'000); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 0); + ASSERT_FALSE(op.hasTimeout()); + VELOX_ASSERT_THROW(op.setGrowTargets(), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), 0); + ASSERT_EQ(op.minGrowBytes(), 0); + + ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 0); + + ASSERT_EQ(op.state(), ArbitrationOperation::State::kInit); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + op.start(); + op.setGrowTargets(); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), requestBytes); + ASSERT_EQ(op.minGrowBytes(), 0); + VELOX_ASSERT_THROW(op.setGrowTargets(), ""); + ASSERT_EQ(op.requestBytes(), requestBytes); + ASSERT_EQ(op.maxGrowBytes(), requestBytes); + ASSERT_EQ(op.minGrowBytes(), 0); + + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + ASSERT_EQ(op.state(), ArbitrationOperation::State::kRunning); + + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); + 
ASSERT_EQ(op.localArbitrationWaitTimeUs(), 0); + op.setGlobalArbitrationWaitTimeUs(2'000); + ASSERT_EQ(op.globalArbitrationWaitTimeUs(), 2'000); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + op.allocatedBytes() = op.maxGrowBytes(); + + op.finish(); + ASSERT_EQ(op.state(), ArbitrationOperation::State::kFinished); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + VELOX_ASSERT_THROW(op.setLocalArbitrationWaitTimeUs(2'000), ""); + VELOX_ASSERT_THROW(op.setGlobalArbitrationWaitTimeUs(2'000), ""); + ASSERT_FALSE(op.hasTimeout()); + const auto execTimeMs = op.executionTimeMs(); + std::this_thread::sleep_for(std::chrono::milliseconds(1'000)); // NOLINT + ASSERT_EQ(op.executionTimeMs(), execTimeMs); + ASSERT_FALSE(op.hasTimeout()); + + // Operation timeout. + { + ArbitrationOperation timedOutOp(participant->lock().value(), 1 << 20, 100); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // NOLINT + ASSERT_TRUE(timedOutOp.hasTimeout()); + + ArbitrationOperation noTimedoutOp( + participant->lock().value(), 1 << 20, 100); + noTimedoutOp.start(); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); // NOLINT + noTimedoutOp.finish(); + ASSERT_FALSE(noTimedoutOp.hasTimeout()); + } + + // Operation abort. 
+ { + ArbitrationOperation abortOp(participant->lock().value(), 1 << 20, 100); + ASSERT_FALSE(abortOp.aborted()); + try { + VELOX_FAIL("abort op"); + } catch (const VeloxRuntimeError& e) { + ASSERT_EQ(scopedParticipant->abort(std::current_exception()), 0); + } + ASSERT_TRUE(abortOp.aborted()); + + ArbitrationOperation abortCheckOp( + participant->lock().value(), 1 << 20, 100); + ASSERT_TRUE(abortCheckOp.aborted()); + } +} + +TEST_F(ArbitrationParticipantTest, arbitrationOperationWait) { + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + auto scopedParticipant = participant->lock().value(); + const int requestBytes = 1 << 20; + const int opTimeoutMs = 1'000'000; + ArbitrationOperation op1( + participant->lock().value(), requestBytes, opTimeoutMs); + ArbitrationOperation op2( + participant->lock().value(), requestBytes, opTimeoutMs); + ArbitrationOperation op3( + participant->lock().value(), requestBytes, opTimeoutMs); + ArbitrationOperation op4(participant->lock().value(), requestBytes, 1'000); + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + + op1.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + ASSERT_EQ(op1.state(), ArbitrationOperation::State::kRunning); + + std::thread op2Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op2.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_FALSE(op2.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 2); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kWaiting); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + ASSERT_EQ(scopedParticipant->numWaitingOps(), 2); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kWaiting); + op2.finish(); + }); + + while (scopedParticipant->numWaitingOps() != 1) { + 
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + std::thread op3Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op3.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_FALSE(op3.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 1); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kWaiting); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + ASSERT_EQ(scopedParticipant->numWaitingOps(), 1); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kWaiting); + op3.finish(); + }); + + while (scopedParticipant->numWaitingOps() != 2) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + std::thread op4Thread([&]() { + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + op4.start(); + ASSERT_TRUE(scopedParticipant->hasRunningOp()); + ASSERT_TRUE(op4.hasTimeout()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); // NOLINT + op4.finish(); + }); + + while (scopedParticipant->numWaitingOps() != 3) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); // NOLINT + } + + ASSERT_EQ(op2.state(), ArbitrationOperation::State::kWaiting); + ASSERT_EQ(op2.state(), ArbitrationOperation::State::kWaiting); + ASSERT_EQ(op2.state(), ArbitrationOperation::State::kWaiting); + + std::this_thread::sleep_for(std::chrono::seconds(1)); + op1.finish(); + ASSERT_EQ(op1.state(), ArbitrationOperation::State::kFinished); + ASSERT_FALSE(op1.hasTimeout()); + ASSERT_GE(op1.executionTimeMs(), 1'000); + + op2Thread.join(); + ASSERT_EQ(op2.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op2.executionTimeMs(), 1'000 + 500); + + op3Thread.join(); + ASSERT_EQ(op3.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op3.executionTimeMs(), 1'000 + 500 + 500); + + op4Thread.join(); + ASSERT_EQ(op4.state(), ArbitrationOperation::State::kFinished); + ASSERT_GE(op4.executionTimeMs(), 
1'000 + 500 + 500); + + ASSERT_FALSE(scopedParticipant->hasRunningOp()); + ASSERT_EQ(scopedParticipant->numWaitingOps(), 0); + + ASSERT_EQ(scopedParticipant->stats().numRequests, 4); +} + +TEST_F(ArbitrationParticipantTest, arbitrationOperationFuzzerTest) { + const int numThreads = 10; + const int numOpsPerThread = 100; + auto task = createTask(kMemoryCapacity); + const auto config = arbitrationConfig(0, 0, 0.0, 0, 0.0); + auto participant = ArbitrationParticipant::create(10, task->pool(), &config); + + std::vector arbitrationThreads; + for (int i = 0; i < numThreads; ++i) { + arbitrationThreads.emplace_back([&, i]() { + folly::Random::DefaultGenerator rng; + rng.seed(i); + for (int j = 0; j < numOpsPerThread; ++j) { + const int numExecutionTimeUs = folly::Random::rand32(0, 1'000, rng); + ArbitrationOperation op(participant->lock().value(), 1 << 20, 1'000); + op.start(); + std::this_thread::sleep_for( + std::chrono::microseconds(numExecutionTimeUs)); // NOLINT + op.finish(); + } + }); + } + for (auto& thread : arbitrationThreads) { + thread.join(); + } + + ASSERT_EQ(participant->stats().numRequests, numThreads * numOpsPerThread); +} + +TEST_F(ArbitrationParticipantTest, arbitrationOperationState) { + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kInit), + "init"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kWaiting), + "waiting"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kRunning), + "running"); + ASSERT_EQ( + ArbitrationOperation::stateName(ArbitrationOperation::State::kFinished), + "finished"); + ASSERT_EQ( + ArbitrationOperation::stateName( + static_cast(10)), + "unknown state: 10"); +} +} // namespace +} // namespace facebook::velox::memory