Skip to content

Commit

Permalink
#Centipede Distiller: Port writing threads to ThreadPool
Browse files Browse the repository at this point in the history
Motivations:
- Primary: Cap the overall parallelism, which is important now that reading is parallelized.
- Secondary: Use uniform APIs for similar things, and use less boilerplate in the process.

PiperOrigin-RevId: 596407826
  • Loading branch information
ussuri authored and copybara-github committed Feb 6, 2024
1 parent 9d0386c commit 122a6e5
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions centipede/distill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include <sstream>
#include <string>
#include <string_view>
#include <thread> // NOLINT(build/c++11)
#include <utility>
#include <vector>

Expand Down Expand Up @@ -78,6 +77,14 @@ inline constexpr perf::MemSize kGB = 1024L * 1024L * 1024L;
// The maximum number of threads reading input shards concurrently. This is
// mainly to prevent I/O congestion.
inline constexpr size_t kMaxReadingThreads = 100;
// The maximum number of threads writing shards concurrently. Each of these in
// turn launches up to `kMaxReadingThreads` reading threads.
inline constexpr size_t kMaxWritingThreads = 10;
// A global cap on the total number of threads, both writing and reading. Unlike
// the other two limits, this one is purely to prevent too many threads in the
// process.
inline constexpr size_t kMaxTotalThreads = 1000;
// Compile-time guard: the product of the two per-kind limits (every writing
// thread running its full set of reading threads) must not exceed the global
// cap, so raising either limit forces a deliberate update of the cap.
static_assert(kMaxReadingThreads * kMaxWritingThreads <= kMaxTotalThreads);

std::string LogPrefix(const Environment &env) {
return absl::StrCat("DISTILL[S.", env.my_shard_index, "]: ");
Expand Down Expand Up @@ -304,28 +311,34 @@ int Distill(const Environment &env) {
constexpr perf::RUsageMemory kRamQuota{.mem_rss = 25 * kGB};
perf::ResourcePool ram_pool{kRamQuota};

// Run `env.num_threads` independent distillation threads.
std::vector<std::thread> threads(env.num_threads);
std::vector<Environment> envs(env.num_threads, env);
std::vector<Environment> envs_per_thread(env.num_threads, env);
std::vector<std::vector<size_t>> shard_indices_per_thread(env.num_threads);
// Start the threads.
// Prepare per-thread envs and input shard indices.
for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) {
envs[thread_idx].my_shard_index += thread_idx;
envs_per_thread[thread_idx].my_shard_index += thread_idx;
// Shuffle the shards, so that every thread produces a different result.
Rng rng(GetRandomSeed(env.seed + thread_idx));
auto &shard_indices = shard_indices_per_thread[thread_idx];
shard_indices.resize(env.total_shards);
std::iota(shard_indices.begin(), shard_indices.end(), 0);
std::shuffle(shard_indices.begin(), shard_indices.end(), rng);
// Run the thread.
threads[thread_idx] =
std::thread(DistillTask, std::ref(envs[thread_idx]), shard_indices,
std::ref(ram_pool), kMaxReadingThreads);
}
// Join threads.
for (size_t thread_idx = 0; thread_idx < env.num_threads; thread_idx++) {
threads[thread_idx].join();
}

// Run the distillation threads in parallel.
{
const size_t num_threads = std::min(env.num_threads, kMaxWritingThreads);
ThreadPool threads{static_cast<int>(num_threads)};
for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) {
threads.Schedule(
[&thread_env = envs_per_thread[thread_idx],
&thread_shard_indices = shard_indices_per_thread[thread_idx],
&ram_pool]() {
DistillTask( //
thread_env, thread_shard_indices, ram_pool, kMaxReadingThreads);
});
}
} // The writing threads join here.

return EXIT_SUCCESS;
}

Expand Down

0 comments on commit 122a6e5

Please sign in to comment.