diff --git a/sdk/src/common/random.cc b/sdk/src/common/random.cc index 77b88cfa2a..9b696dea65 100644 --- a/sdk/src/common/random.cc +++ b/sdk/src/common/random.cc @@ -5,6 +5,7 @@ #include "src/common/random.h" #include "src/common/platform/fork.h" +#include #include #include @@ -26,12 +27,17 @@ class TlsRandomNumberGenerator TlsRandomNumberGenerator() noexcept { Seed(); - platform::AtFork(nullptr, nullptr, OnFork); + if (!flag.test_and_set()) + { + platform::AtFork(nullptr, nullptr, OnFork); + } } static FastRandomNumberGenerator &engine() noexcept { return engine_; } private: + static std::atomic_flag flag; + static thread_local FastRandomNumberGenerator engine_; static void OnFork() noexcept { Seed(); } @@ -44,6 +50,7 @@ class TlsRandomNumberGenerator } }; +std::atomic_flag TlsRandomNumberGenerator::flag; thread_local FastRandomNumberGenerator TlsRandomNumberGenerator::engine_{}; } // namespace diff --git a/sdk/test/common/random_test.cc b/sdk/test/common/random_test.cc index 8132b2d5a4..243a998041 100644 --- a/sdk/test/common/random_test.cc +++ b/sdk/test/common/random_test.cc @@ -4,7 +4,10 @@ #include "src/common/random.h" #include +#include #include +#include +#include #include using opentelemetry::sdk::common::Random; @@ -34,3 +37,27 @@ TEST(RandomTest, GenerateRandomBuffer) std::equal(std::begin(buf1_vector), std::end(buf1_vector), std::begin(buf2_vector))); } } + +void doSomethingOnce(std::atomic_uint *count) +{ + static std::atomic_flag flag; + if (!flag.test_and_set()) + { + (*count)++; + } +} + +TEST(RandomTest, AtomicFlagMultiThreadTest) +{ + std::vector threads; + std::atomic_uint count(0); + for (int i = 0; i < 10; ++i) + { + threads.push_back(std::thread(doSomethingOnce, &count)); + } + for (auto &t : threads) + { + t.join(); + } + EXPECT_EQ(1, count.load()); +}