diff --git a/centipede/distill.cc b/centipede/distill.cc index 4c6a8e359..3cea93227 100644 --- a/centipede/distill.cc +++ b/centipede/distill.cc @@ -24,7 +24,6 @@ #include #include #include -#include // NOLINT(build/c++11) #include #include @@ -78,6 +77,14 @@ inline constexpr perf::MemSize kGB = 1024L * 1024L * 1024L; // The maximum number of threads reading input shards concurrently. This is // mainly to prevent I/O congestion. inline constexpr size_t kMaxReadingThreads = 100; +// The maximum number of threads writing shards concurrently. These in turn +// launch up to `kMaxReadingThreads` reading threads. +inline constexpr size_t kMaxWritingThreads = 10; +// A global cap on the total number of threads, both writing and reading. Unlike +// the other two limits, this one is purely to prevent too many threads in the +// process. +inline constexpr size_t kMaxTotalThreads = 1000; +static_assert(kMaxReadingThreads * kMaxWritingThreads <= kMaxTotalThreads); std::string LogPrefix(const Environment &env) { return absl::StrCat("DISTILL[S.", env.my_shard_index, "]: "); @@ -304,28 +311,34 @@ int Distill(const Environment &env) { constexpr perf::RUsageMemory kRamQuota{.mem_rss = 25 * kGB}; perf::ResourcePool ram_pool{kRamQuota}; - // Run `env.num_threads` independent distillation threads. - std::vector threads(env.num_threads); - std::vector envs(env.num_threads, env); + std::vector envs_per_thread(env.num_threads, env); std::vector> shard_indices_per_thread(env.num_threads); - // Start the threads. + // Prepare per-thread envs and input shard indices. for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) { - envs[thread_idx].my_shard_index += thread_idx; + envs_per_thread[thread_idx].my_shard_index += thread_idx; // Shuffle the shards, so that every thread produces different result. Rng rng(GetRandomSeed(env.seed + thread_idx)); auto &shard_indices = shard_indices_per_thread[thread_idx]; shard_indices.resize(env.total_shards); std::iota(shard_indices.begin(), shard_indices.end(), 0); std::shuffle(shard_indices.begin(), shard_indices.end(), rng); - // Run the thread. - threads[thread_idx] = - std::thread(DistillTask, std::ref(envs[thread_idx]), shard_indices, - std::ref(ram_pool), kMaxReadingThreads); - } - // Join threads. - for (size_t thread_idx = 0; thread_idx < env.num_threads; thread_idx++) { - threads[thread_idx].join(); } + + // Run the distillation threads in parallel. + { + const size_t num_threads = std::min(env.num_threads, kMaxWritingThreads); + ThreadPool threads{static_cast(num_threads)}; + for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) { + threads.Schedule( + [&thread_env = envs_per_thread[thread_idx], + &thread_shard_indices = shard_indices_per_thread[thread_idx], + &ram_pool]() { + DistillTask( // + thread_env, thread_shard_indices, ram_pool, kMaxReadingThreads); + }); + } + } // The writing threads join here. + return EXIT_SUCCESS; }