Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL] In OAP Velox fork, simplify shrinking/spilling code for Velox memory pool #3586

Merged
merged 14 commits into from
Nov 2, 2023
10 changes: 3 additions & 7 deletions cpp/velox/benchmarks/QueryBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,10 @@ const std::string getFilePath(const std::string& fileName) {

// Used by unit test and benchmark.
std::shared_ptr<ResultIterator> getResultIterator(
std::shared_ptr<velox::memory::MemoryPool> veloxPool,
VeloxMemoryManager* memoryManager,
Runtime* runtime,
const std::vector<std::shared_ptr<SplitInfo>>& setScanInfos,
std::shared_ptr<const facebook::velox::core::PlanNode>& veloxPlan) {
auto ctxPool = veloxPool->addAggregateChild(
"query_benchmark_result_iterator", facebook::velox::memory::MemoryReclaimer::create());

std::vector<std::shared_ptr<ResultIterator>> inputIter;
std::unordered_map<std::string, std::string> sessionConf = {};
auto veloxPlanConverter =
Expand All @@ -57,7 +54,7 @@ std::shared_ptr<ResultIterator> getResultIterator(
veloxPlanConverter->splitInfos(), veloxPlan->leafPlanNodeIds(), scanInfos, scanIds, streamIds);

auto wholestageIter = std::make_unique<WholeStageResultIteratorFirstStage>(
ctxPool,
memoryManager,
veloxPlan,
scanIds,
setScanInfos,
Expand All @@ -78,7 +75,6 @@ auto BM = [](::benchmark::State& state,

auto memoryManager = getDefaultMemoryManager();
auto runtime = Runtime::create(kVeloxRuntimeKind);
auto veloxPool = memoryManager->getAggregateMemoryPool();

std::vector<std::shared_ptr<SplitInfo>> scanInfos;
scanInfos.reserve(datasetPaths.size());
Expand All @@ -96,7 +92,7 @@ auto BM = [](::benchmark::State& state,

runtime->parsePlan(reinterpret_cast<uint8_t*>(plan.data()), plan.size());
std::shared_ptr<const facebook::velox::core::PlanNode> veloxPlan;
auto resultIter = getResultIterator(veloxPool, runtime, scanInfos, veloxPlan);
auto resultIter = getResultIterator(memoryManager.get(), runtime, scanInfos, veloxPlan);
auto outputSchema = toArrowSchema(veloxPlan->outputType(), defaultLeafVeloxMemoryPool().get());
while (resultIter->hasNext()) {
auto array = resultIter->next()->exportArrowArray();
Expand Down
7 changes: 3 additions & 4 deletions cpp/velox/compute/VeloxRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ std::shared_ptr<ResultIterator> VeloxRuntime::createResultIterator(
printSessionConf(sessionConf);
#endif

auto veloxPool = getAggregateVeloxPool(memoryManager);

VeloxPlanConverter veloxPlanConverter(inputs, getLeafVeloxPool(memoryManager).get(), sessionConf);
veloxPlan_ = veloxPlanConverter.toVeloxPlan(substraitPlan_);

Expand All @@ -91,15 +89,16 @@ std::shared_ptr<ResultIterator> VeloxRuntime::createResultIterator(
// Separate the scan ids and stream ids, and get the scan infos.
getInfoAndIds(veloxPlanConverter.splitInfos(), veloxPlan_->leafPlanNodeIds(), scanInfos, scanIds, streamIds);

auto* vmm = toVeloxMemoryManager(memoryManager);
if (scanInfos.size() == 0) {
// Source node is not required.
auto wholestageIter = std::make_unique<WholeStageResultIteratorMiddleStage>(
veloxPool, veloxPlan_, streamIds, spillDir, sessionConf, taskInfo_);
vmm, veloxPlan_, streamIds, spillDir, sessionConf, taskInfo_);
auto resultIter = std::make_shared<ResultIterator>(std::move(wholestageIter), this);
return resultIter;
} else {
auto wholestageIter = std::make_unique<WholeStageResultIteratorFirstStage>(
veloxPool, veloxPlan_, scanIds, scanInfos, streamIds, spillDir, sessionConf, taskInfo_);
vmm, veloxPlan_, scanIds, scanInfos, streamIds, spillDir, sessionConf, taskInfo_);
auto resultIter = std::make_shared<ResultIterator>(std::move(wholestageIter), this);
return resultIter;
}
Expand Down
14 changes: 7 additions & 7 deletions cpp/velox/compute/VeloxRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ class VeloxRuntime final : public Runtime {
explicit VeloxRuntime(const std::unordered_map<std::string, std::string>& confMap);

static std::shared_ptr<facebook::velox::memory::MemoryPool> getAggregateVeloxPool(MemoryManager* memoryManager) {
if (auto veloxMemoryManager = dynamic_cast<VeloxMemoryManager*>(memoryManager)) {
return veloxMemoryManager->getAggregateMemoryPool();
} else {
GLUTEN_CHECK(false, "Should use VeloxMemoryManager here.");
}
return toVeloxMemoryManager(memoryManager)->getAggregateMemoryPool();
}

static std::shared_ptr<facebook::velox::memory::MemoryPool> getLeafVeloxPool(MemoryManager* memoryManager) {
return toVeloxMemoryManager(memoryManager)->getLeafMemoryPool();
}

static VeloxMemoryManager* toVeloxMemoryManager(MemoryManager* memoryManager) {
if (auto veloxMemoryManager = dynamic_cast<VeloxMemoryManager*>(memoryManager)) {
return veloxMemoryManager->getLeafMemoryPool();
return veloxMemoryManager;
} else {
GLUTEN_CHECK(false, "Should use VeloxMemoryManager here.");
GLUTEN_CHECK(false, "Velox memory manager should be used for Velox runtime.");
}
}

Expand Down
28 changes: 12 additions & 16 deletions cpp/velox/compute/WholeStageResultIterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ const std::string kHiveDefaultPartition = "__HIVE_DEFAULT_PARTITION__";
} // namespace

WholeStageResultIterator::WholeStageResultIterator(
std::shared_ptr<facebook::velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const facebook::velox::core::PlanNode>& planNode,
const std::unordered_map<std::string, std::string>& confMap,
const SparkTaskInfo& taskInfo)
: veloxPlan_(planNode), confMap_(confMap), taskInfo_(taskInfo), pool_(pool) {
: veloxPlan_(planNode), confMap_(confMap), taskInfo_(taskInfo), memoryManager_(memoryManager) {
#ifdef ENABLE_HDFS
updateHdfsTokens();
#endif
Expand All @@ -100,7 +100,7 @@ std::shared_ptr<velox::core::QueryCtx> WholeStageResultIterator::createNewVeloxQ
facebook::velox::core::QueryConfig{getQueryContextConf()},
connectorConfigs,
gluten::VeloxBackend::get()->getAsyncDataCache(),
pool_,
memoryManager_->getAggregateMemoryPool(),
nullptr,
"");
return ctx;
Expand Down Expand Up @@ -153,15 +153,10 @@ class ConditionalSuspendedSection {
} // namespace

int64_t WholeStageResultIterator::spillFixedSize(int64_t size) {
std::string poolName{pool_->root()->name() + "/" + pool_->name()};
auto pool = memoryManager_->getAggregateMemoryPool();
std::string poolName{pool->root()->name() + "/" + pool->name()};
std::string logPrefix{"Spill[" + poolName + "]: "};
VLOG(2) << logPrefix << "Trying to reclaim " << size << " bytes of data...";
VLOG(2) << logPrefix << "Pool has reserved " << pool_->currentBytes() << "/" << pool_->root()->reservedBytes() << "/"
<< pool_->root()->capacity() << "/" << pool_->root()->maxCapacity() << " bytes.";
VLOG(2) << logPrefix << "Shrinking...";
int64_t shrunken = pool_->shrinkManaged(pool_.get(), size);
VLOG(2) << logPrefix << shrunken << " bytes released from shrinking.";

int64_t shrunken = memoryManager_->shrink(size);
// todo return the actual spilled size?
if (spillStrategy_ == "auto") {
int64_t remaining = size - shrunken;
Expand All @@ -176,7 +171,8 @@ int64_t WholeStageResultIterator::spillFixedSize(int64_t size) {
// suspend the driver when we are on it
ConditionalSuspendedSection noCancel(thisDriver, thisDriver != nullptr);
velox::exec::MemoryReclaimer::Stats status;
uint64_t spilledOut = pool_->reclaim(remaining, status);
auto* mm = memoryManager_->getMemoryManager();
uint64_t spilledOut = mm->arbitrator()->shrinkMemory({pool}, remaining); // this conducts spilling
LOG(INFO) << logPrefix << "Successfully spilled out " << spilledOut << " bytes.";
uint64_t total = shrunken + spilledOut;
VLOG(2) << logPrefix << "Successfully reclaimed total " << total << " bytes.";
Expand Down Expand Up @@ -400,15 +396,15 @@ std::shared_ptr<velox::Config> WholeStageResultIterator::createConnectorConfig()
}

WholeStageResultIteratorFirstStage::WholeStageResultIteratorFirstStage(
std::shared_ptr<velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const velox::core::PlanNode>& planNode,
const std::vector<velox::core::PlanNodeId>& scanNodeIds,
const std::vector<std::shared_ptr<SplitInfo>>& scanInfos,
const std::vector<velox::core::PlanNodeId>& streamIds,
const std::string spillDir,
const std::unordered_map<std::string, std::string>& confMap,
const SparkTaskInfo& taskInfo)
: WholeStageResultIterator(pool, planNode, confMap, taskInfo),
: WholeStageResultIterator(memoryManager, planNode, confMap, taskInfo),
scanNodeIds_(scanNodeIds),
scanInfos_(scanInfos),
streamIds_(streamIds) {
Expand Down Expand Up @@ -495,13 +491,13 @@ void WholeStageResultIteratorFirstStage::constructPartitionColumns(
}

WholeStageResultIteratorMiddleStage::WholeStageResultIteratorMiddleStage(
std::shared_ptr<velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const velox::core::PlanNode>& planNode,
const std::vector<velox::core::PlanNodeId>& streamIds,
const std::string spillDir,
const std::unordered_map<std::string, std::string>& confMap,
const SparkTaskInfo& taskInfo)
: WholeStageResultIterator(pool, planNode, confMap, taskInfo), streamIds_(streamIds) {
: WholeStageResultIterator(memoryManager, planNode, confMap, taskInfo), streamIds_(streamIds) {
std::unordered_set<velox::core::PlanNodeId> emptySet;
velox::core::PlanFragment planFragment{planNode, velox::core::ExecutionStrategy::kUngrouped, 1, emptySet};
std::shared_ptr<velox::core::QueryCtx> queryCtx = createNewVeloxQueryCtx();
Expand Down
8 changes: 4 additions & 4 deletions cpp/velox/compute/WholeStageResultIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace gluten {
class WholeStageResultIterator : public ColumnarBatchIterator {
public:
WholeStageResultIterator(
std::shared_ptr<facebook::velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const facebook::velox::core::PlanNode>& planNode,
const std::unordered_map<std::string, std::string>& confMap,
const SparkTaskInfo& taskInfo);
Expand Down Expand Up @@ -96,7 +96,7 @@ class WholeStageResultIterator : public ColumnarBatchIterator {
const std::unordered_map<std::string, facebook::velox::RuntimeMetric>& runtimeStats,
const std::string& metricId);

std::shared_ptr<facebook::velox::memory::MemoryPool> pool_;
VeloxMemoryManager* memoryManager_;

// spill
std::string spillStrategy_;
Expand All @@ -113,7 +113,7 @@ class WholeStageResultIterator : public ColumnarBatchIterator {
class WholeStageResultIteratorFirstStage final : public WholeStageResultIterator {
public:
WholeStageResultIteratorFirstStage(
std::shared_ptr<facebook::velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const facebook::velox::core::PlanNode>& planNode,
const std::vector<facebook::velox::core::PlanNodeId>& scanNodeIds,
const std::vector<std::shared_ptr<SplitInfo>>& scanInfos,
Expand All @@ -137,7 +137,7 @@ class WholeStageResultIteratorFirstStage final : public WholeStageResultIterator
class WholeStageResultIteratorMiddleStage final : public WholeStageResultIterator {
public:
WholeStageResultIteratorMiddleStage(
std::shared_ptr<facebook::velox::memory::MemoryPool> pool,
VeloxMemoryManager* memoryManager,
const std::shared_ptr<const facebook::velox::core::PlanNode>& planNode,
const std::vector<facebook::velox::core::PlanNodeId>& streamIds,
const std::string spillDir,
Expand Down
53 changes: 34 additions & 19 deletions cpp/velox/memory/VeloxMemoryManager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "VeloxMemoryManager.h"
#include "velox/common/memory/MallocAllocator.h"
#include "velox/common/memory/MemoryPool.h"
#include "velox/exec/MemoryReclaimer.h"

#include "memory/ArrowMemoryPool.h"
#include "utils/exception.h"
Expand Down Expand Up @@ -62,26 +63,28 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

void reserveMemory(velox::memory::MemoryPool* pool, uint64_t) override {
growPool(pool, memoryPoolInitCapacity_);
std::lock_guard<std::recursive_mutex> l(mutex_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍
Could we remove AllocationListener's mutex?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep that mutex for now, since this one only works in the Velox backend.

growPoolLocked(pool, memoryPoolInitCapacity_);
}

uint64_t releaseMemory(velox::memory::MemoryPool* pool, uint64_t bytes) override {
uint64_t freeBytes = pool->shrink(bytes);
listener_->allocationChanged(-freeBytes);
if (bytes == 0 && pool->capacity() != 0) {
// So far only MemoryManager::dropPool() calls with 0 bytes. Let's assert the pool
// gives all capacity back to Spark
//
// We are likely in destructor, do not throw. INFO log is fine since we have leak checks from Spark's memory
// manager
LOG(INFO) << "Memory pool " << pool->name() << " not completely shrunken when Memory::dropPool() is called";
}
return freeBytes;
void releaseMemory(velox::memory::MemoryPool* pool) override {
std::lock_guard<std::recursive_mutex> l(mutex_);
releaseMemoryLocked(pool);
}

uint64_t shrinkMemory(const std::vector<std::shared_ptr<velox::memory::MemoryPool>>& pools, uint64_t targetBytes)
override {
GLUTEN_CHECK(false, "Not implemented");
facebook::velox::exec::MemoryReclaimer::Stats status;
GLUTEN_CHECK(pools.size() == 1, "Should shrink a single pool at a time");
std::lock_guard<std::recursive_mutex> l(mutex_); // FIXME: Do we have recursive locking for this mutex?
auto pool = pools.at(0);
const uint64_t oldCapacity = pool->capacity();
uint64_t spilledOut = pool->reclaim(targetBytes, status); // ignore the output
uint64_t shrunken = pool->shrink(0);
const uint64_t newCapacity = pool->capacity();
uint64_t total = oldCapacity - newCapacity;
listener_->allocationChanged(-total);
return total;
}

bool growMemory(
Expand All @@ -91,7 +94,10 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
GLUTEN_CHECK(candidatePools.size() == 1, "ListenableArbitrator should only be used within a single root pool");
auto candidate = candidatePools.back();
GLUTEN_CHECK(pool->root() == candidate.get(), "Illegal state in ListenableArbitrator");
growPool(pool, targetBytes);
{
std::lock_guard<std::recursive_mutex> l(mutex_);
growPoolLocked(pool, targetBytes);
}
return true;
}

Expand All @@ -105,12 +111,18 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator {
}

private:
void growPool(velox::memory::MemoryPool* pool, uint64_t bytes) {
void growPoolLocked(velox::memory::MemoryPool* pool, uint64_t bytes) {
listener_->allocationChanged(bytes);
pool->grow(bytes);
}

// Shrinks `pool` and reports the number of bytes it gave back to the listener
// as a negative allocation change.
//
// The caller is expected to hold mutex_ already: releaseMemory() takes the
// lock and then delegates here.
//
// NOTE(review): shrink(0) presumably asks the pool to return all of its free
// capacity - confirm against velox::memory::MemoryPool::shrink.
void releaseMemoryLocked(velox::memory::MemoryPool* pool) {
  const uint64_t freeBytes = pool->shrink(0);
  // Negate in the signed domain. Applying unary minus to the uint64_t directly
  // yields a huge unsigned value and only produces the intended negative delta
  // through wraparound plus an implicit conversion at the call.
  listener_->allocationChanged(-static_cast<int64_t>(freeBytes));
}

gluten::AllocationListener* listener_;
std::recursive_mutex mutex_;
inline static std::string kind_ = "GLUTEN";
};

Expand Down Expand Up @@ -194,14 +206,17 @@ MemoryUsageStats collectMemoryUsageStatsInternal(const velox::memory::MemoryPool
return stats;
}

int64_t shrinkVeloxMemoryPool(velox::memory::MemoryPool* pool, int64_t size) {
int64_t shrinkVeloxMemoryPool(velox::memory::MemoryManager* mm, velox::memory::MemoryPool* pool, int64_t size) {
std::string poolName{pool->root()->name() + "/" + pool->name()};
std::string logPrefix{"Shrink[" + poolName + "]: "};
VLOG(2) << logPrefix << "Trying to shrink " << size << " bytes of data...";
VLOG(2) << logPrefix << "Pool has reserved " << pool->currentBytes() << "/" << pool->root()->reservedBytes() << "/"
<< pool->root()->capacity() << "/" << pool->root()->maxCapacity() << " bytes.";
VLOG(2) << logPrefix << "Shrinking...";
int64_t shrunken = pool->shrinkManaged(pool, size);
const uint64_t oldCapacity = pool->capacity();
mm->arbitrator()->releaseMemory(pool);
const uint64_t newCapacity = pool->capacity();
int64_t shrunken = oldCapacity - newCapacity;
VLOG(2) << logPrefix << shrunken << " bytes released from shrinking.";
return shrunken;
}
Expand All @@ -212,7 +227,7 @@ const MemoryUsageStats VeloxMemoryManager::collectMemoryUsageStats() const {
}

const int64_t VeloxMemoryManager::shrink(int64_t size) {
return shrinkVeloxMemoryPool(veloxAggregatePool_.get(), size);
return shrinkVeloxMemoryPool(veloxMemoryManager_.get(), veloxAggregatePool_.get(), size);
}

namespace {
Expand Down
4 changes: 4 additions & 0 deletions cpp/velox/memory/VeloxMemoryManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class VeloxMemoryManager final : public MemoryManager {
return veloxLeafPool_;
}

// Exposes the underlying Velox memory manager as a raw pointer taken from
// veloxMemoryManager_.get(); no ownership is handed over by this call.
facebook::velox::memory::MemoryManager* getMemoryManager() const {
  auto* veloxManager = veloxMemoryManager_.get();
  return veloxManager;
}

arrow::MemoryPool* getArrowMemoryPool() override {
return arrowPool_.get();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.glutenproject.memory.nmm;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ReservationListener} wrapper that logs every {@code reserve} / {@code unreserve} call
 * passing through it. For debugging purposes only.
 *
 * <p>Each call is forwarded to the wrapped listener. The used-byte count reported by the wrapped
 * listener is read before and after the forwarded call, and both values are written to the log at
 * INFO level together with the requested size and the value the wrapped listener returned.
 */
public class LoggingReservationListener implements ReservationListener {
  private static final Logger LOGGER = LoggerFactory.getLogger(LoggingReservationListener.class);

  // The listener every call is forwarded to.
  private final ReservationListener delegated;

  /**
   * @param delegated the listener that receives the forwarded calls
   */
  public LoggingReservationListener(ReservationListener delegated) {
    this.delegated = delegated;
  }

  /**
   * Forwards the reservation to the wrapped listener and logs the used bytes before and after.
   *
   * @param size the size passed on to the wrapped listener
   * @return the value returned by the wrapped listener's {@code reserve}
   */
  @Override
  public long reserve(long size) {
    long before = getUsedBytes();
    long reserved = delegated.reserve(size);
    long after = getUsedBytes();
    // Parameterized message: the text is only assembled when INFO is enabled, and the numbers
    // are rendered independently of the default locale (unlike String.format("%d")).
    LOGGER.info("Reservation[{}]: {} + {}({}) = {}", this, before, reserved, size, after);
    return reserved;
  }

  /**
   * Forwards the unreservation to the wrapped listener and logs the used bytes before and after.
   *
   * @param size the size passed on to the wrapped listener
   * @return the value returned by the wrapped listener's {@code unreserve}
   */
  @Override
  public long unreserve(long size) {
    long before = getUsedBytes();
    long unreserved = delegated.unreserve(size);
    long after = getUsedBytes();
    // Same lazy, locale-independent formatting as in reserve().
    LOGGER.info("Unreservation[{}]: {} - {}({}) = {}", this, before, unreserved, size, after);
    return unreserved;
  }

  /**
   * @return the used-byte count reported by the wrapped listener
   */
  @Override
  public long getUsedBytes() {
    return delegated.getUsedBytes();
  }
}
Loading
Loading