Skip to content

Commit

Permalink
Add per-test memory sanity check in OperatorTestBase
Browse files Browse the repository at this point in the history
  • Loading branch information
tanjialiang committed Apr 2, 2024
1 parent 8acb197 commit b9e2eb3
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 26 deletions.
1 change: 1 addition & 0 deletions velox/common/memory/tests/SharedArbitratorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase {
}

void TearDown() override {
vector_.reset();
HiveConnectorTestBase::TearDown();
}

Expand Down
3 changes: 3 additions & 0 deletions velox/exec/tests/DriverTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class DriverTest : public OperatorTestBase {
// NOTE: destroy the tasks first to release all the allocated memory held
// by the plan nodes (Values) in tasks.
tasks_.clear();
waitForAllTasksToBeDeleted();

if (wakeupInitialized_) {
wakeupCancelled_ = true;
Expand Down Expand Up @@ -1480,4 +1481,6 @@ TEST_F(OpCallStatusTest, basic) {

task->start(1, 1);
ASSERT_TRUE(waitForTaskCompletion(task.get(), 600'000'000));
task.reset();
waitForAllTasksToBeDeleted();
};
40 changes: 38 additions & 2 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,11 @@ class HashJoinTest : public HiveConnectorTestBase {
.allowLazyVector = false};
}

void TearDown() override {
waitForAllTasksToBeDeleted();
HiveConnectorTestBase::TearDown();
}

// Make splits with each plan node having a number of source files.
SplitInput makeSpiltInput(
const std::vector<core::PlanNodeId>& nodeIds,
Expand Down Expand Up @@ -6399,7 +6404,6 @@ TEST_F(HashJoinTest, maxSpillBytes) {
e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded);
}
}
waitForAllTasksToBeDeleted();
}

TEST_F(HashJoinTest, onlyHashBuildMaxSpillBytes) {
Expand Down Expand Up @@ -6494,6 +6498,10 @@ TEST_F(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) {
auto& planStats = taskStats.at(result.planNodeId);
ASSERT_GT(planStats.spilledBytes, 0);
result.task.reset();

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
ASSERT_GT(arbitrator->stats().numRequests, 0);
ASSERT_GT(arbitrator->stats().numReclaimedBytes, 0);
Expand Down Expand Up @@ -6568,6 +6576,10 @@ DEBUG_ONLY_TEST_F(
memoryArbitrationWait.notifyAll();

joinThread.join();

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
ASSERT_EQ(arbitrator->stats().numNonReclaimableAttempts, 2);
}
Expand Down Expand Up @@ -6652,10 +6664,13 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimFromHashJoinBuildInWaitForTableBuild) {

// We expect the reclaimed bytes from hash build.
ASSERT_GT(arbitrator->stats().numReclaimedBytes, 0);

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
ASSERT_TRUE(fakeBuffer != nullptr);
fakePool->free(fakeBuffer, kMemoryCapacity);
waitForAllTasksToBeDeleted();
}

DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) {
Expand Down Expand Up @@ -6707,6 +6722,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) {
.assertResults(
"SELECT t.c1 FROM tmp as t, tmp AS u WHERE t.c0 == u.c1 AND t.c1 == u.c0");
ASSERT_TRUE(parallelBuildTriggered);

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
}

Expand Down Expand Up @@ -6793,6 +6812,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) {
.assertResults(
"SELECT t.c1 FROM tmp as t, tmp AS u WHERE t.c0 == u.c1 AND t.c1 == u.c0");
task.reset();

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
ASSERT_EQ(injectAllocations.size(), 2);
}
Expand Down Expand Up @@ -6887,6 +6910,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringJoinTableBuild) {

joinThread.join();
memThread.join();

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
}

Expand Down Expand Up @@ -6947,6 +6974,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) {
waitForAllTasksToBeDeleted();
ASSERT_EQ(arbitrator->stats().numFailures, 1);
ASSERT_EQ(arbitrator->stats().numReserves, 1);

// Wait again here as this test uses on-demand created memory manager instead
// of the global one. We need to make sure any used memory got cleaned up
// before exiting the scope
waitForAllTasksToBeDeleted();
}

DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) {
Expand Down Expand Up @@ -7024,6 +7056,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) {
buildBlockWait.notifyAll();

queryThread.join();

// This test uses on-demand created memory manager instead of the global
// one. We need to make sure any used memory got cleaned up before exiting
// the scope
waitForAllTasksToBeDeleted();
}
}
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/tests/MultiFragmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ class MultiFragmentTest : public HiveConnectorTestBase {
exec::ExchangeSource::registerFactory(createLocalExchangeSource);
}

void TearDown() override {
waitForAllTasksToBeDeleted();
vectors_.clear();
HiveConnectorTestBase::TearDown();
}

static std::string makeTaskId(const std::string& prefix, int num) {
return fmt::format("local://{}-{}", prefix, num);
}
Expand Down
16 changes: 12 additions & 4 deletions velox/exec/tests/OperatorUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,18 @@
using namespace facebook::velox;
using namespace facebook::velox::test;
using namespace facebook::velox::exec;
using namespace facebook::velox::exec::test;

class OperatorUtilsTest
: public ::facebook::velox::exec::test::OperatorTestBase {
class OperatorUtilsTest : public OperatorTestBase {
protected:
void TearDown() override {
driverCtx_.reset();
driver_.reset();
task_.reset();
waitForAllTasksToBeDeleted();
OperatorTestBase::TearDown();
}

OperatorUtilsTest() {
VectorMaker vectorMaker{pool_.get()};
std::vector<RowVectorPtr> values = {vectorMaker.rowVector(
Expand Down Expand Up @@ -124,8 +132,8 @@ class OperatorUtilsTest
}
}

std::shared_ptr<memory::MemoryPool> pool_{
memory::memoryManager()->addLeafPool()};
// std::shared_ptr<memory::MemoryPool> pool_{
// memory::memoryManager()->addLeafPool()};
std::shared_ptr<Task> task_;
std::shared_ptr<Driver> driver_;
std::unique_ptr<DriverCtx> driverCtx_;
Expand Down
14 changes: 10 additions & 4 deletions velox/exec/tests/SortBufferTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ class SortBufferTest : public OperatorTestBase {
rng_.seed(123);
}

void TearDown() override {
pool_.reset();
rootPool_.reset();
OperatorTestBase::TearDown();
}

common::SpillConfig getSpillConfig(const std::string& spillDir) const {
return common::SpillConfig(
[&]() -> const std::string& { return spillDir; },
Expand Down Expand Up @@ -74,14 +80,14 @@ class SortBufferTest : public OperatorTestBase {
{true, true, false, CompareFlags::NullHandlingMode::kNullAsValue}};

const int64_t maxBytes_ = 20LL << 20; // 20 MB
const std::shared_ptr<memory::MemoryPool> rootPool_{
memory::memoryManager()->addRootPool("SortBufferTest", maxBytes_)};
const std::shared_ptr<memory::MemoryPool> pool_{
rootPool_->addLeafChild("SortBufferTest", maxBytes_)};
const std::shared_ptr<folly::Executor> executor_{
std::make_shared<folly::CPUThreadPoolExecutor>(
std::thread::hardware_concurrency())};

std::shared_ptr<memory::MemoryPool> rootPool_{
memory::memoryManager()->addRootPool("SortBufferTest", maxBytes_)};
std::shared_ptr<memory::MemoryPool> pool_{
rootPool_->addLeafChild("SortBufferTest", maxBytes_)};
tsan_atomic<bool> nonReclaimableSection_{false};
folly::Random::DefaultGenerator rng_;
};
Expand Down
18 changes: 12 additions & 6 deletions velox/exec/tests/SqlTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ namespace facebook::velox::exec::test {

class SqlTest : public OperatorTestBase {
protected:
void TearDown() override {
planner_.reset();
OperatorTestBase::TearDown();
}

void assertSql(const std::string& sql, const std::string& duckSql = "") {
auto plan = planner_.plan(sql);
auto plan = planner_->plan(sql);
AssertQueryBuilder(plan, duckDbQueryRunner_)
.assertResults(duckSql.empty() ? sql : duckSql);
}

core::DuckDbQueryPlanner planner_{pool()};
std::unique_ptr<core::DuckDbQueryPlanner> planner_{
std::make_unique<core::DuckDbQueryPlanner>(pool())};
};

TEST_F(SqlTest, values) {
Expand All @@ -40,7 +46,7 @@ TEST_F(SqlTest, values) {
}

TEST_F(SqlTest, customScalarFunctions) {
planner_.registerScalarFunction(
planner_->registerScalarFunction(
"array_join", {ARRAY(BIGINT()), VARCHAR()}, VARCHAR());

assertSql("SELECT array_join([1, 2, 3], '-')", "SELECT '1-2-3'");
Expand All @@ -49,7 +55,7 @@ TEST_F(SqlTest, customScalarFunctions) {
TEST_F(SqlTest, customAggregateFunctions) {
// We need an aggregate that DuckDB does not support. 'every' fits the need.
// 'every' is an alias for bool_and().
planner_.registerAggregateFunction("every", {BOOLEAN()}, BOOLEAN());
planner_->registerAggregateFunction("every", {BOOLEAN()}, BOOLEAN());

assertSql(
"SELECT every(x) FROM UNNEST([true, false, true]) as t(x)",
Expand Down Expand Up @@ -81,8 +87,8 @@ TEST_F(SqlTest, tableScan) {
createDuckDbTable("t", data.at("t"));
createDuckDbTable("u", data.at("u"));

planner_.registerTable("t", data.at("t"));
planner_.registerTable("u", data.at("u"));
planner_->registerTable("t", data.at("t"));
planner_->registerTable("u", data.at("u"));

assertSql("SELECT a, avg(b) FROM t WHERE c > 5 GROUP BY 1");
assertSql("SELECT * FROM t, u WHERE t.a = u.a");
Expand Down
6 changes: 6 additions & 0 deletions velox/exec/tests/TaskTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,14 @@ class TestBadMemoryTranslator : public exec::Operator::PlanNodeTranslator {
}
};
} // namespace

class TaskTest : public HiveConnectorTestBase {
protected:
void TearDown() override {
waitForAllTasksToBeDeleted();
HiveConnectorTestBase::TearDown();
}

static std::pair<std::shared_ptr<exec::Task>, std::vector<RowVectorPtr>>
executeSingleThreaded(
core::PlanFragment plan,
Expand Down
7 changes: 7 additions & 0 deletions velox/exec/tests/ValuesTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ namespace facebook::velox::exec::test {

class ValuesTest : public OperatorTestBase {
protected:
void TearDown() override {
waitForAllTasksToBeDeleted();
input_.reset();
input2_.reset();
OperatorTestBase::TearDown();
}

// Sample row vectors.
RowVectorPtr input_{makeRowVector({
makeFlatVector<int32_t>({0, 1, 2, 3, 5}),
Expand Down
32 changes: 22 additions & 10 deletions velox/exec/tests/utils/OperatorTestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ void OperatorTestBase::SetUpTestCase() {
FLAGS_velox_enable_memory_usage_track_in_default_memory_pool = true;
FLAGS_velox_memory_leak_check_enabled = true;
memory::SharedArbitrator::registerFactory();
resetMemory();
functions::prestosql::registerAllScalarFunctions();
aggregate::prestosql::registerAllAggregateFunctions();
TestValue::enable();
}

void OperatorTestBase::TearDownTestCase() {
asyncDataCache_->shutdown();
waitForAllTasksToBeDeleted();
memory::SharedArbitrator::unregisterFactory();
}

void OperatorTestBase::resetMemory() {
if (asyncDataCache_ != nullptr) {
asyncDataCache_->clear();
asyncDataCache_.reset();
}
MemoryManagerOptions options;
options.allocatorCapacity = 8L << 30;
options.arbitratorCapacity = 6L << 30;
Expand All @@ -75,15 +92,6 @@ void OperatorTestBase::SetUpTestCase() {
asyncDataCache_ =
cache::AsyncDataCache::create(memory::memoryManager()->allocator());
cache::AsyncDataCache::setInstance(asyncDataCache_.get());
functions::prestosql::registerAllScalarFunctions();
aggregate::prestosql::registerAllAggregateFunctions();
TestValue::enable();
}

void OperatorTestBase::TearDownTestCase() {
asyncDataCache_->shutdown();
waitForAllTasksToBeDeleted();
memory::SharedArbitrator::unregisterFactory();
}

void OperatorTestBase::SetUp() {
Expand All @@ -94,7 +102,11 @@ void OperatorTestBase::SetUp() {
ioExecutor_ = std::make_unique<folly::IOThreadPoolExecutor>(3);
}

void OperatorTestBase::TearDown() {}
void OperatorTestBase::TearDown() {
pool_.reset();
rootPool_.reset();
resetMemory();
}

std::shared_ptr<Task> OperatorTestBase::assertQuery(
const core::PlanNodePtr& plan,
Expand Down
2 changes: 2 additions & 0 deletions velox/exec/tests/utils/OperatorTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class OperatorTestBase : public testing::Test,

static void TearDownTestCase();

static void resetMemory();

void createDuckDbTable(const std::vector<RowVectorPtr>& data) {
duckDbQueryRunner_.createTable("tmp", data);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ TEST_F(ApproxPercentileTest, partialFull) {
makeFlatVector<int32_t>(117, [](auto row) { return row < 7 ? 20 : 10; }),
});
exec::test::assertQuery(params, {expected});
waitForAllTasksToBeDeleted();
}

TEST_F(ApproxPercentileTest, finalAggregateAccuracy) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ class MinMaxByAggregationTestBase : public AggregationTestBase {

void SetUp() override;

void TearDown() override {
dataVectorsByType_.clear();
rowVectors_.clear();
AggregationTestBase::TearDown();
}

// Build a flat vector with numeric native type of T. The value in the
// returned flat vector is in ascending order.
template <typename T>
Expand Down

0 comments on commit b9e2eb3

Please sign in to comment.