From 6d273227c763b9b7f397874229a2e7158c10d906 Mon Sep 17 00:00:00 2001 From: Jialiang Tan Date: Mon, 18 Mar 2024 16:24:59 -0700 Subject: [PATCH] Add per-test memory sanity check in OperatorTestBase --- .../memory/tests/MemoryCapExceededTest.cpp | 1 + .../memory/tests/SharedArbitratorTest.cpp | 1 + velox/exec/tests/DriverTest.cpp | 3 ++ velox/exec/tests/HashJoinTest.cpp | 40 ++++++++++++++++++- velox/exec/tests/MultiFragmentTest.cpp | 11 +++++ velox/exec/tests/OperatorUtilsTest.cpp | 14 +++++-- velox/exec/tests/SortBufferTest.cpp | 11 ++--- velox/exec/tests/SqlTest.cpp | 18 ++++++--- velox/exec/tests/TableWriteTest.cpp | 1 + velox/exec/tests/TaskTest.cpp | 6 +++ velox/exec/tests/ValuesTest.cpp | 7 ++++ velox/exec/tests/utils/OperatorTestBase.cpp | 32 ++++++++++----- velox/exec/tests/utils/OperatorTestBase.h | 6 ++- .../aggregates/tests/ApproxPercentileTest.cpp | 1 + .../tests/MinMaxByAggregationTest.cpp | 6 +++ 15 files changed, 130 insertions(+), 28 deletions(-) diff --git a/velox/common/memory/tests/MemoryCapExceededTest.cpp b/velox/common/memory/tests/MemoryCapExceededTest.cpp index 6960d751857d6..848df1e102df1 100644 --- a/velox/common/memory/tests/MemoryCapExceededTest.cpp +++ b/velox/common/memory/tests/MemoryCapExceededTest.cpp @@ -37,6 +37,7 @@ class MemoryCapExceededTest : public OperatorTestBase, } void TearDown() override { + waitForAllTasksToBeDeleted(); OperatorTestBase::TearDown(); FLAGS_velox_suppress_memory_capacity_exceeding_error_message = false; } diff --git a/velox/common/memory/tests/SharedArbitratorTest.cpp b/velox/common/memory/tests/SharedArbitratorTest.cpp index a0b5fc5c90fdf..56dc6f04c12f7 100644 --- a/velox/common/memory/tests/SharedArbitratorTest.cpp +++ b/velox/common/memory/tests/SharedArbitratorTest.cpp @@ -257,6 +257,7 @@ class SharedArbitrationTest : public exec::test::HiveConnectorTestBase { } void TearDown() override { + vector_.reset(); HiveConnectorTestBase::TearDown(); } diff --git a/velox/exec/tests/DriverTest.cpp b/velox/exec/tests/DriverTest.cpp index 64f3492acc99d..df5986d420309 100644 --- a/velox/exec/tests/DriverTest.cpp +++ b/velox/exec/tests/DriverTest.cpp @@ -100,6 +100,7 @@ class DriverTest : public OperatorTestBase { // NOTE: destroy the tasks first to release all the allocated memory held // by the plan nodes (Values) in tasks. tasks_.clear(); + waitForAllTasksToBeDeleted(); if (wakeupInitialized_) { wakeupCancelled_ = true; @@ -1480,4 +1481,6 @@ TEST_F(OpCallStatusTest, basic) { task->start(1, 1); ASSERT_TRUE(waitForTaskCompletion(task.get(), 600'000'000)); + task.reset(); + waitForAllTasksToBeDeleted(); }; diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 06920f5cb3171..fd5d9858a3da2 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -792,6 +792,11 @@ class HashJoinTest : public HiveConnectorTestBase { .allowLazyVector = false}; } + void TearDown() override { + waitForAllTasksToBeDeleted(); + HiveConnectorTestBase::TearDown(); + } + // Make splits with each plan node having a number of source files. SplitInput makeSpiltInput( const std::vector& nodeIds, @@ -6408,7 +6413,6 @@ TEST_F(HashJoinTest, maxSpillBytes) { e.errorCode(), facebook::velox::error_code::kSpillLimitExceeded); } } - waitForAllTasksToBeDeleted(); } TEST_F(HashJoinTest, onlyHashBuildMaxSpillBytes) { @@ -6503,6 +6507,10 @@ TEST_F(HashJoinTest, reclaimFromJoinBuilderWithMultiDrivers) { auto& planStats = taskStats.at(result.planNodeId); ASSERT_GT(planStats.spilledBytes, 0); result.task.reset(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); ASSERT_GT(arbitrator->stats().numRequests, 0); ASSERT_GT(arbitrator->stats().numReclaimedBytes, 0); @@ -6577,6 +6585,10 @@ DEBUG_ONLY_TEST_F( memoryArbitrationWait.notifyAll(); joinThread.join(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); ASSERT_EQ(arbitrator->stats().numNonReclaimableAttempts, 2); } @@ -6661,10 +6673,13 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimFromHashJoinBuildInWaitForTableBuild) { // We expect the reclaimed bytes from hash build. ASSERT_GT(arbitrator->stats().numReclaimedBytes, 0); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); ASSERT_TRUE(fakeBuffer != nullptr); fakePool->free(fakeBuffer, kMemoryCapacity); - waitForAllTasksToBeDeleted(); } DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { @@ -6716,6 +6731,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredDuringParallelJoinBuild) { .assertResults( "SELECT t.c1 FROM tmp as t, tmp AS u WHERE t.c0 == u.c1 AND t.c1 == u.c0"); ASSERT_TRUE(parallelBuildTriggered); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); } @@ -6802,6 +6821,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, arbitrationTriggeredByEnsureJoinTableFit) { .assertResults( "SELECT t.c1 FROM tmp as t, tmp AS u WHERE t.c0 == u.c1 AND t.c1 == u.c0"); task.reset(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); ASSERT_EQ(injectAllocations.size(), 2); } @@ -6896,6 +6919,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringJoinTableBuild) { joinThread.join(); memThread.join(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); } @@ -6956,6 +6983,11 @@ DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) { waitForAllTasksToBeDeleted(); ASSERT_EQ(arbitrator->stats().numFailures, 1); ASSERT_EQ(arbitrator->stats().numReserves, 1); + + // Wait again here as this test uses on-demand created memory manager instead + // of the global one. We need to make sure any used memory got cleaned up + // before exiting the scope + waitForAllTasksToBeDeleted(); } DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) { @@ -7033,6 +7065,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, taskWaitTimeout) { buildBlockWait.notifyAll(); queryThread.join(); + + // This test uses on-demand created memory manager instead of the global + // one. We need to make sure any used memory got cleaned up before exiting + // the scope waitForAllTasksToBeDeleted(); } } diff --git a/velox/exec/tests/MultiFragmentTest.cpp b/velox/exec/tests/MultiFragmentTest.cpp index 1e9af9fc7f840..3e8f4c03517b2 100644 --- a/velox/exec/tests/MultiFragmentTest.cpp +++ b/velox/exec/tests/MultiFragmentTest.cpp @@ -46,6 +46,17 @@ class MultiFragmentTest : public HiveConnectorTestBase { exec::ExchangeSource::registerFactory(createLocalExchangeSource); } + void TearDown() override { + waitForAllTasksToBeDeleted(); + + // There might be lingering exchange source on executor even after all tasks + // are deleted. This can cause memory leak because exchange source holds + // reference to memory pool. We need to make sure they are properly cleaned. + testingShutdownLocalExchangeSource(); + vectors_.clear(); + HiveConnectorTestBase::TearDown(); + } + static std::string makeTaskId(const std::string& prefix, int num) { return fmt::format("local://{}-{}", prefix, num); } diff --git a/velox/exec/tests/OperatorUtilsTest.cpp b/velox/exec/tests/OperatorUtilsTest.cpp index 09fff638f09d3..24ce90dd7480f 100644 --- a/velox/exec/tests/OperatorUtilsTest.cpp +++ b/velox/exec/tests/OperatorUtilsTest.cpp @@ -23,10 +23,18 @@ using namespace facebook::velox; using namespace facebook::velox::test; using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; -class OperatorUtilsTest - : public ::facebook::velox::exec::test::OperatorTestBase { +class OperatorUtilsTest : public OperatorTestBase { protected: + void TearDown() override { + driverCtx_.reset(); + driver_.reset(); + task_.reset(); + waitForAllTasksToBeDeleted(); + OperatorTestBase::TearDown(); + } + OperatorUtilsTest() { VectorMaker vectorMaker{pool_.get()}; std::vector values = {vectorMaker.rowVector( @@ -124,8 +132,6 @@ class OperatorUtilsTest } } - std::shared_ptr pool_{ - memory::memoryManager()->addLeafPool()}; std::shared_ptr task_; std::shared_ptr driver_; std::unique_ptr driverCtx_; diff --git a/velox/exec/tests/SortBufferTest.cpp b/velox/exec/tests/SortBufferTest.cpp index 83a23fbefeea6..771b16cad06df 100644 --- a/velox/exec/tests/SortBufferTest.cpp +++ b/velox/exec/tests/SortBufferTest.cpp @@ -42,6 +42,12 @@ class SortBufferTest : public OperatorTestBase { rng_.seed(123); } + void TearDown() override { + pool_.reset(); + rootPool_.reset(); + OperatorTestBase::TearDown(); + } + common::SpillConfig getSpillConfig(const std::string& spillDir) const { return common::SpillConfig( [&]() -> const std::string& { return spillDir; }, @@ -73,11 +79,6 @@ class SortBufferTest : public OperatorTestBase { {true, true, false, CompareFlags::NullHandlingMode::kNullAsValue}, {true, true, false, CompareFlags::NullHandlingMode::kNullAsValue}}; - const int64_t maxBytes_ = 20LL << 20; // 20 MB - const std::shared_ptr rootPool_{ - memory::memoryManager()->addRootPool("SortBufferTest", maxBytes_)}; - const std::shared_ptr pool_{ - rootPool_->addLeafChild("SortBufferTest", maxBytes_)}; const std::shared_ptr executor_{ std::make_shared( std::thread::hardware_concurrency())}; diff --git a/velox/exec/tests/SqlTest.cpp b/velox/exec/tests/SqlTest.cpp index d318f5a9d4ee0..8ad448474e9e0 100644 --- a/velox/exec/tests/SqlTest.cpp +++ b/velox/exec/tests/SqlTest.cpp @@ -21,13 +21,19 @@ namespace facebook::velox::exec::test { class SqlTest : public OperatorTestBase { protected: + void TearDown() override { + planner_.reset(); + OperatorTestBase::TearDown(); + } + void assertSql(const std::string& sql, const std::string& duckSql = "") { - auto plan = planner_.plan(sql); + auto plan = planner_->plan(sql); AssertQueryBuilder(plan, duckDbQueryRunner_) .assertResults(duckSql.empty() ? sql : duckSql); } - core::DuckDbQueryPlanner planner_{pool()}; + std::unique_ptr planner_{ + std::make_unique(pool())}; }; TEST_F(SqlTest, values) { @@ -40,7 +46,7 @@ TEST_F(SqlTest, values) { } TEST_F(SqlTest, customScalarFunctions) { - planner_.registerScalarFunction( + planner_->registerScalarFunction( "array_join", {ARRAY(BIGINT()), VARCHAR()}, VARCHAR()); assertSql("SELECT array_join([1, 2, 3], '-')", "SELECT '1-2-3'"); @@ -49,7 +55,7 @@ TEST_F(SqlTest, customScalarFunctions) { TEST_F(SqlTest, customAggregateFunctions) { // We need an aggregate that DuckDB does not support. 'every' fits the need. // 'every' is an alias for bool_and(). - planner_.registerAggregateFunction("every", {BOOLEAN()}, BOOLEAN()); + planner_->registerAggregateFunction("every", {BOOLEAN()}, BOOLEAN()); assertSql( "SELECT every(x) FROM UNNEST([true, false, true]) as t(x)", @@ -81,8 +87,8 @@ TEST_F(SqlTest, tableScan) { createDuckDbTable("t", data.at("t")); createDuckDbTable("u", data.at("u")); - planner_.registerTable("t", data.at("t")); - planner_.registerTable("u", data.at("u")); + planner_->registerTable("t", data.at("t")); + planner_->registerTable("u", data.at("u")); assertSql("SELECT a, avg(b) FROM t WHERE c > 5 GROUP BY 1"); assertSql("SELECT * FROM t, u WHERE t.a = u.a"); diff --git a/velox/exec/tests/TableWriteTest.cpp b/velox/exec/tests/TableWriteTest.cpp index bf4ad2e264862..ee114a1ec01b4 100644 --- a/velox/exec/tests/TableWriteTest.cpp +++ b/velox/exec/tests/TableWriteTest.cpp @@ -248,6 +248,7 @@ class TableWriteTest : public HiveConnectorTestBase { } void TearDown() override { + waitForAllTasksToBeDeleted(); HiveConnectorTestBase::TearDown(); } diff --git a/velox/exec/tests/TaskTest.cpp b/velox/exec/tests/TaskTest.cpp index e2fb897b6b2d5..3bf0a1d487a6b 100644 --- a/velox/exec/tests/TaskTest.cpp +++ b/velox/exec/tests/TaskTest.cpp @@ -457,8 +457,14 @@ class TestBadMemoryTranslator : public exec::Operator::PlanNodeTranslator { } }; } // namespace + class TaskTest : public HiveConnectorTestBase { protected: + void TearDown() override { + waitForAllTasksToBeDeleted(); + HiveConnectorTestBase::TearDown(); + } + static std::pair, std::vector> executeSingleThreaded( core::PlanFragment plan, diff --git a/velox/exec/tests/ValuesTest.cpp b/velox/exec/tests/ValuesTest.cpp index fc58e6351470a..f9d79872af45c 100644 --- a/velox/exec/tests/ValuesTest.cpp +++ b/velox/exec/tests/ValuesTest.cpp @@ -25,6 +25,13 @@ namespace facebook::velox::exec::test { class ValuesTest : public OperatorTestBase { protected: + void TearDown() override { + waitForAllTasksToBeDeleted(); + input_.reset(); + input2_.reset(); + OperatorTestBase::TearDown(); + } + // Sample row vectors. RowVectorPtr input_{makeRowVector({ makeFlatVector({0, 1, 2, 3, 5}), diff --git a/velox/exec/tests/utils/OperatorTestBase.cpp b/velox/exec/tests/utils/OperatorTestBase.cpp index d0366ef8d7845..7ad53cc8889a4 100644 --- a/velox/exec/tests/utils/OperatorTestBase.cpp +++ b/velox/exec/tests/utils/OperatorTestBase.cpp @@ -64,6 +64,23 @@ void OperatorTestBase::SetUpTestCase() { FLAGS_velox_enable_memory_usage_track_in_default_memory_pool = true; FLAGS_velox_memory_leak_check_enabled = true; memory::SharedArbitrator::registerFactory(); + resetMemory(); + functions::prestosql::registerAllScalarFunctions(); + aggregate::prestosql::registerAllAggregateFunctions(); + TestValue::enable(); +} + +void OperatorTestBase::TearDownTestCase() { + asyncDataCache_->shutdown(); + waitForAllTasksToBeDeleted(); + memory::SharedArbitrator::unregisterFactory(); +} + +void OperatorTestBase::resetMemory() { + if (asyncDataCache_ != nullptr) { + asyncDataCache_->clear(); + asyncDataCache_.reset(); + } MemoryManagerOptions options; options.allocatorCapacity = 8L << 30; options.arbitratorCapacity = 6L << 30; @@ -75,15 +92,6 @@ void OperatorTestBase::SetUpTestCase() { asyncDataCache_ = cache::AsyncDataCache::create(memory::memoryManager()->allocator()); cache::AsyncDataCache::setInstance(asyncDataCache_.get()); - functions::prestosql::registerAllScalarFunctions(); - aggregate::prestosql::registerAllAggregateFunctions(); - TestValue::enable(); -} - -void OperatorTestBase::TearDownTestCase() { - asyncDataCache_->shutdown(); - waitForAllTasksToBeDeleted(); - memory::SharedArbitrator::unregisterFactory(); } void OperatorTestBase::SetUp() { @@ -94,7 +102,11 @@ void OperatorTestBase::SetUp() { ioExecutor_ = std::make_unique(3); } -void OperatorTestBase::TearDown() {} +void OperatorTestBase::TearDown() { + pool_.reset(); + rootPool_.reset(); + resetMemory(); +} std::shared_ptr OperatorTestBase::assertQuery( const core::PlanNodePtr& plan, diff --git a/velox/exec/tests/utils/OperatorTestBase.h b/velox/exec/tests/utils/OperatorTestBase.h index f7d5bf91782f3..e1fcd091da1e4 100644 --- a/velox/exec/tests/utils/OperatorTestBase.h +++ b/velox/exec/tests/utils/OperatorTestBase.h @@ -33,13 +33,17 @@ namespace facebook::velox::exec::test { class OperatorTestBase : public testing::Test, public velox::test::VectorTestBase { public: - /// The following two methods are used by google unit test framework to do + /// The following methods are used by google unit test framework to do /// one-time setup/teardown for all the unit tests from OperatorTestBase. We /// make them public as some benchmark like ReduceAgg also call these methods /// to setup/teardown benchmark test environment. static void SetUpTestCase(); static void TearDownTestCase(); + /// Sets up the velox memory system. A second call to this will clear the + /// previous memory system instances and create a new set. + static void resetMemory(); + protected: OperatorTestBase(); ~OperatorTestBase() override; diff --git a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp index 48065a1a30f49..6e131c1c6190a 100644 --- a/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ApproxPercentileTest.cpp @@ -379,6 +379,7 @@ TEST_F(ApproxPercentileTest, partialFull) { makeFlatVector(117, [](auto row) { return row < 7 ? 20 : 10; }), }); exec::test::assertQuery(params, {expected}); + waitForAllTasksToBeDeleted(); } TEST_F(ApproxPercentileTest, finalAggregateAccuracy) { diff --git a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp index 2ea761de486b4..2f5cd38b0fb12 100644 --- a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp @@ -144,6 +144,12 @@ class MinMaxByAggregationTestBase : public AggregationTestBase { void SetUp() override; + void TearDown() override { + dataVectorsByType_.clear(); + rowVectors_.clear(); + AggregationTestBase::TearDown(); + } + // Build a flat vector with numeric native type of T. The value in the // returned flat vector is in ascending order. template