Skip to content

Commit

Permalink
test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 24, 2024
1 parent ca82504 commit cc673a6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
9 changes: 0 additions & 9 deletions src/common/device_vector.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocato
*/
class LoggingResource : public rmm::mr::device_memory_resource {
rmm::mr::device_memory_resource *mr_{rmm::mr::get_current_device_resource()};
std::mutex lock_;

public:
LoggingResource() = default;
Expand All @@ -407,10 +406,6 @@ class LoggingResource : public rmm::mr::device_memory_resource {
}

void *do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override { // NOLINT
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
try {
auto const ptr = mr_->allocate(bytes, stream);
GlobalMemoryLogger().RegisterAllocation(ptr, bytes);
Expand All @@ -423,10 +418,6 @@ class LoggingResource : public rmm::mr::device_memory_resource {

void do_deallocate(void *ptr, std::size_t bytes, // NOLINT
rmm::cuda_stream_view stream) override {
std::unique_lock<std::mutex> guard{lock_, std::defer_lock};
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
guard.lock();
}
mr_->deallocate(ptr, bytes, stream);
GlobalMemoryLogger().RegisterDeallocation(ptr, bytes);
}
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/common/test_device_vector.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright 2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <thread> // for thread

#include <numeric> // for iota
#include <thrust/detail/sequence.inl> // for sequence
Expand Down Expand Up @@ -115,4 +116,22 @@ TEST(TestVirtualMem, Version) {
ASSERT_FALSE(pinned.IsVm());
}
}

TEST(AtomitFetch, Max) {
auto n_threads = std::thread::hardware_concurrency();
std::vector<std::thread> threads;
std::atomic<std::int64_t> n{0};
decltype(n)::value_type add = 64;
for (decltype(n_threads) t = 0; t < n_threads; ++t) {
threads.emplace_back([=, &n] {
for (std::size_t i = 0; i < add; ++i) {
detail::AtomicFetchMax(n, static_cast<std::int64_t>(t + i));
}
});
}
for (auto& t : threads) {
t.join();
}
ASSERT_EQ(n, n_threads - 1 + add - 1); // 0-based indexing
}
} // namespace dh

0 comments on commit cc673a6

Please sign in to comment.