Skip to content

Commit

Permalink
#Centipede Distiller: Port writing threads to ThreadPool
Browse files Browse the repository at this point in the history
Motivations:
- Primary: Cap the overall parallelism, which is important now that reading is parallelized.
- Secondary: Use uniform APIs for similar things, and use less boilerplate in the process.

PiperOrigin-RevId: 596407826
  • Loading branch information
ussuri authored and copybara-github committed Jan 26, 2024
1 parent a53a208 commit 564b831
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 73 deletions.
4 changes: 4 additions & 0 deletions centipede/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -897,11 +897,14 @@ cc_library(
":logging",
":rusage_profiler",
":shard_reader",
":thread_pool",
":util",
":workdir",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
],
)
Expand Down Expand Up @@ -1235,6 +1238,7 @@ cc_test(
deps = [
":remote_file",
":test_util",
"@com_google_absl//absl/log:check",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
279 changes: 207 additions & 72 deletions centipede/distill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
#include "./centipede/distill.h"

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <thread> // NOLINT(build/c++11)
#include <string_view>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "./centipede/blob_file.h"
#include "./centipede/defs.h"
Expand All @@ -35,83 +40,208 @@
#include "./centipede/logging.h"
#include "./centipede/rusage_profiler.h"
#include "./centipede/shard_reader.h"
#include "./centipede/thread_pool.h"
#include "./centipede/util.h"
#include "./centipede/workdir.h"

namespace centipede {

using CorpusElt = std::pair<ByteArray, FeatureVec>;
using CorpusEltVec = std::vector<CorpusElt>;
namespace {

void DistillTask(const Environment &env,
const std::vector<size_t> &shard_indices) {
const std::string log_line =
absl::StrCat("DISTILL[S.", env.my_shard_index, "]: ");
struct CorpusElt {
ByteArray input;
FeatureVec features;

const WorkDir wd{env};
const auto corpus_path = wd.DistilledCorpusFiles().MyShardPath();
const auto features_path = wd.DistilledFeaturesFiles().MyShardPath();
LOG(INFO) << log_line << VV(env.total_shards) << VV(corpus_path)
<< VV(features_path);
ByteArray PackedFeatures() const {
return PackFeaturesAndHash(input, features);
}
};

using CorpusEltVec = std::vector<CorpusElt>;

const auto corpus_writer = DefaultBlobFileWriterFactory(env.riegeli);
const auto features_writer = DefaultBlobFileWriterFactory(env.riegeli);
// NOTE: Overwrite distilled corpus and features files -- do not append.
CHECK_OK(corpus_writer->Open(corpus_path, "w"));
CHECK_OK(features_writer->Open(features_path, "w"));
// The maximum number of threads reading input shards concurrently. This is
// mainly to prevent I/O congestion.
// TODO(ussuri): Bump up significantly when RSS-gated mutexing is in.
inline constexpr size_t kMaxReadingThreads = 1;
// The maximum number of threads writing shards concurrently. These in turn
// launch up to `kMaxReadingThreads` reading threads.
inline constexpr size_t kMaxWritingThreads = 10;
// A global cap on the total number of threads, both writing and reading. Unlike
// the other two limits, this one is purely to prevent too many threads in the
// process.
inline constexpr size_t kMaxTotalThreads = 1000;
static_assert(kMaxReadingThreads * kMaxWritingThreads <= kMaxTotalThreads);

FeatureSet feature_set(/*frequency_threshold=*/1,
env.MakeDomainDiscardMask());
// Builds the prefix put in front of this distillation task's log messages.
// The prefix embeds `env.my_shard_index`, so that lines coming from different
// tasks can be told apart.
std::string LogPrefix(const Environment &env) {
  const auto shard_index = env.my_shard_index;
  return absl::StrCat("DISTILL[S.", shard_index, "]: ");
}

const size_t num_shards = shard_indices.size();
size_t num_read_shards = 0;
size_t num_read_elements = 0;
size_t num_distilled_elements = 0;
const auto corpus_files = wd.CorpusFiles();
const auto features_files = wd.FeaturesFiles();
// TODO(ussuri): Move the reader/writer classes to shard_reader.cc, rename it
// to corpus_io.cc, and reuse the new APIs where useful in the code base.

for (size_t shard_idx : shard_indices) {
const std::string corpus_path = corpus_files.ShardPath(shard_idx);
const std::string features_path = features_files.ShardPath(shard_idx);
// A helper class for reading input corpus shards. Thread-safe.
class InputCorpusShardReader {
public:
InputCorpusShardReader(const Environment &env) : env_{env} {}

VLOG(2) << log_line << "reading input shard " << shard_idx << ":\n"
// Reads and returns a single shard's elements. Thread-safe.
CorpusEltVec ReadShard(size_t shard_idx) {
const WorkDir wd{env_};
const auto corpus_path = wd.CorpusFiles().ShardPath(shard_idx);
const auto features_path = wd.FeaturesFiles().ShardPath(shard_idx);
VLOG(1) << LogPrefix(env_) << "reading input shard " << shard_idx << ":\n"
<< VV(corpus_path) << "\n"
<< VV(features_path);

CorpusEltVec elts;
// Read elements from the current shard.
CorpusEltVec shard_elts;
ReadShard(corpus_path, features_path,
[&shard_elts](const ByteArray &input, FeatureVec &features) {
shard_elts.emplace_back(input, std::move(features));
});
// Reverse the order of inputs read from the current shard.
// The intuition is as follows:
// * If the shard is the result of fuzzing with Centipede, the inputs that
// are closer to the end are more interesting, so we start there.
// * If the shard resulted from somethening else, the reverse order is not
// any better or worse than any other order.
std::reverse(shard_elts.begin(), shard_elts.end());
++num_read_shards;

// Iterate the elts, add those that have new features.
// This is a simple linear greedy set cover algorithm.
VLOG(1) << log_line << "appending elements from input shard " << shard_idx
<< " to output shard";
for (auto &[input, features] : shard_elts) {
++num_read_elements;
feature_set.PruneDiscardedDomains(features);
if (!feature_set.HasUnseenFeatures(features)) continue;
feature_set.IncrementFrequencies(features);
centipede::ReadShard( //
corpus_path, features_path,
[&elts](const ByteArray &input, FeatureVec &features) {
elts.emplace_back(input, std::move(features));
});
++num_read_shards_;
return elts;
}

size_t num_read_shards() const { return num_read_shards_; }

private:
Environment env_;
std::atomic<size_t> num_read_shards_ = 0;
};

// A helper class for writing corpus shards. Thread-safe by virtue of enforcing
// exclusive locking in the function annotations.
class CorpusShardWriter {
public:
CorpusShardWriter(const Environment &env, std::string_view mode)
: env_{env},
corpus_writer_{DefaultBlobFileWriterFactory()},
feature_writer_{DefaultBlobFileWriterFactory()} {
const WorkDir wd{env};
corpus_path_ = wd.DistilledCorpusFiles().MyShardPath();
features_path_ = wd.DistilledFeaturesFiles().MyShardPath();
CHECK_OK(corpus_writer_->Open(corpus_path_, mode));
CHECK_OK(feature_writer_->Open(features_path_, mode));
}

virtual ~CorpusShardWriter() = default;

void WriteElt(CorpusElt elt) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
++num_total_elts_;
if (PreprocessElt(elt) == EltDisposition::kWrite) {
// Append to the distilled corpus and features files.
CHECK_OK(corpus_writer->Write(input));
CHECK_OK(features_writer->Write(PackFeaturesAndHash(input, features)));
++num_distilled_elements;
VLOG_EVERY_N(10, 1000) << VV(num_distilled_elements);
CHECK_OK(corpus_writer_->Write(elt.input));
CHECK_OK(feature_writer_->Write(elt.PackedFeatures()));
++num_written_elts_;
}
}

void WriteBatch(CorpusEltVec elts) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(1) << LogPrefix(env_) << "writing " << elts.size()
<< " elements to output shard:\n"
<< VV(corpus_path_) << "\n"
<< VV(features_path_);
for (auto &elt : elts) {
WriteElt(std::move(elt));
}
LOG(INFO) << log_line << feature_set << " src_shards: " << num_read_shards
<< "/" << num_shards << " src_elts: " << num_read_elements
<< " dist_elts: " << num_distilled_elements;
++num_written_batches_;
}

// Counter accessors. The counters are guarded by `mu_`, so the caller must
// hold the mutex returned by `Mutex()`. They are `const` because they only
// read the state.

// The number of elements passed to `WriteElt()` so far, whether or not they
// ended up being written.
size_t num_total_elts() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return num_total_elts_;
}
// The number of elements actually written to the output files so far.
size_t num_written_elts() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return num_written_elts_;
}
// The number of completed `WriteBatch()` calls so far.
size_t num_written_batches() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return num_written_batches_;
}

absl::Mutex &Mutex() ABSL_LOCK_RETURNED(mu_) { return mu_; }

protected:
// The verdict returned by `PreprocessElt()`: write the element to the output
// files or skip it. `[[nodiscard]]` has to follow the `enum class` keywords to
// apply to the type; written in front of them it is ignored by the compiler.
enum class [[nodiscard]] EltDisposition { kWrite, kIgnore };

// A behavior customization point: a derived class gets an opportunity to
// analyze and/or preprocess `elt` before it is written. In particular, the
// derived class can choose to skip writing entirely by returning `kIgnore`.
virtual EltDisposition PreprocessElt(CorpusElt &elt) {
return EltDisposition::kWrite;
}

private:
Environment env_;
std::string corpus_path_;
std::string features_path_;

absl::Mutex mu_;

std::unique_ptr<BlobFileWriter> corpus_writer_ ABSL_GUARDED_BY(mu_);
std::unique_ptr<BlobFileWriter> feature_writer_ ABSL_GUARDED_BY(mu_);
size_t num_total_elts_ ABSL_GUARDED_BY(mu_) = 0;
size_t num_written_elts_ ABSL_GUARDED_BY(mu_) = 0;
size_t num_written_batches_ ABSL_GUARDED_BY(mu_) = 0;
};

// A helper class for writing distilled corpus shards: an element is written
// only if, after pruning the discarded feature domains, it still carries at
// least one feature that has not been seen in any previously written element.
// All other elements are dropped.
//
// NOT thread-safe on its own because all writes go to a single file and the
// accumulated feature set has no lock of its own.
// NOTE(review): callers presumably need to hold the base class' `Mutex()`
// around `feature_set()` as well as around the write calls -- confirm.
class DistilledCorpusShardWriter : public CorpusShardWriter {
 public:
  // Opens the output files via the base class using `mode`, and starts with
  // an empty feature set that uses a frequency threshold of 1 and the domain
  // discard mask derived from `env`.
  DistilledCorpusShardWriter(const Environment &env, std::string_view mode)
      : CorpusShardWriter{env, mode},
        feature_set_(/*frequency_threshold=*/1, env.MakeDomainDiscardMask()) {}

  ~DistilledCorpusShardWriter() override = default;

  // The features accumulated from all the elements written so far. `const`
  // because it only exposes a read-only view of the state.
  const FeatureSet &feature_set() const { return feature_set_; }

 protected:
  // Prunes the discarded domains from `elt.features` in place, then decides
  // whether `elt` is worth writing: returns `kIgnore` if it adds no unseen
  // features, otherwise records its features and returns `kWrite`.
  EltDisposition PreprocessElt(CorpusElt &elt) override {
    feature_set_.PruneDiscardedDomains(elt.features);
    if (!feature_set_.HasUnseenFeatures(elt.features))
      return EltDisposition::kIgnore;
    feature_set_.IncrementFrequencies(elt.features);
    return EltDisposition::kWrite;
  }

 private:
  FeatureSet feature_set_;
};

} // namespace

void DistillTask(const Environment &env,
const std::vector<size_t> &shard_indices) {
// Read and write the shards in parallel, but gate reading of each on the
// availability of free RAM to keep the peak RAM usage under control.
const size_t num_shards = shard_indices.size();
InputCorpusShardReader reader{env};
// NOTE: Always overwrite corpus and features files, never append.
DistilledCorpusShardWriter writer{env, "w"};

{
ThreadPool threads{kMaxReadingThreads};
for (size_t shard_idx : shard_indices) {
threads.Schedule([shard_idx, &reader, &writer, &env, num_shards] {
CorpusEltVec shard_elts = reader.ReadShard(shard_idx);
// Reverse the order of elements. The intuition is as follows:
// * If the shard is the result of fuzzing with Centipede, the inputs
// that are closer to the end are more interesting, so we start there.
// * If the shard resulted from somethening else, the reverse order is
// not any better or worse than any other order.
std::reverse(shard_elts.begin(), shard_elts.end());
{
absl::WriterMutexLock lock(&writer.Mutex());
writer.WriteBatch(std::move(shard_elts));
LOG(INFO) << LogPrefix(env) << writer.feature_set()
<< " src_shards: " << writer.num_written_batches() << "/"
<< num_shards << " src_elts: " << writer.num_total_elts()
<< " dst_elts: " << writer.num_written_elts();
}
});
}
} // The reading threads join here.
}

int Distill(const Environment &env) {
Expand All @@ -120,27 +250,32 @@ int Distill(const Environment &env) {
/*timelapse_interval=*/absl::Seconds(VLOG_IS_ON(2) ? 10 : 60), //
/*also_log_timelapses=*/VLOG_IS_ON(10));

// Run `env.num_threads` independent distillation threads.
std::vector<std::thread> threads(env.num_threads);
std::vector<Environment> envs(env.num_threads, env);
std::vector<Environment> envs_per_thread(env.num_threads, env);
std::vector<std::vector<size_t>> shard_indices_per_thread(env.num_threads);
// Start the threads.
// Prepare per-thread envs and input shard indices.
for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) {
envs[thread_idx].my_shard_index += thread_idx;
envs_per_thread[thread_idx].my_shard_index += thread_idx;
// Shuffle the shards, so that every thread produces different result.
Rng rng(GetRandomSeed(env.seed + thread_idx));
auto &shard_indices = shard_indices_per_thread[thread_idx];
shard_indices.resize(env.total_shards);
std::iota(shard_indices.begin(), shard_indices.end(), 0);
std::shuffle(shard_indices.begin(), shard_indices.end(), rng);
// Run the thread.
threads[thread_idx] =
std::thread(DistillTask, std::ref(envs[thread_idx]), shard_indices);
}
// Join threads.
for (size_t thread_idx = 0; thread_idx < env.num_threads; thread_idx++) {
threads[thread_idx].join();
}

// Run the distillation threads in parallel.
{
const size_t num_threads = std::min(env.num_threads, kMaxWritingThreads);
ThreadPool threads{static_cast<int>(num_threads)};
for (size_t thread_idx = 0; thread_idx < env.num_threads; ++thread_idx) {
threads.Schedule(
[&thread_env = envs_per_thread[thread_idx],
&thread_shard_indices = shard_indices_per_thread[thread_idx]]() {
DistillTask(thread_env, thread_shard_indices);
});
}
} // The writing threads join here.

return EXIT_SUCCESS;
}

Expand Down
10 changes: 10 additions & 0 deletions centipede/remote_file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <glob.h>

#include <cstdint>
#include <cstdio>
#include <filesystem> // NOLINT
#include <memory>
Expand Down Expand Up @@ -138,6 +139,15 @@ ABSL_ATTRIBUTE_WEAK bool RemotePathExists(std::string_view path) {
return std::filesystem::exists(path);
}

// Returns the size, in bytes, of the file at `path`. CHECK-fails if the file
// cannot be opened or if its size cannot be determined.
ABSL_ATTRIBUTE_WEAK int64_t RemoteFileGetSize(std::string_view path) {
  // A `std::string_view` is not guaranteed to be null-terminated, so handing
  // `path.data()` to `fopen()` could read past the end of the view. Go through
  // an owning, null-terminated copy of the path instead.
  const std::filesystem::path fs_path{path};
  FILE *f = std::fopen(fs_path.string().c_str(), "r");
  CHECK(f != nullptr) << VV(path);
  const int seek_result = std::fseek(f, 0, SEEK_END);
  // `ftell()` returns -1 on failure; do not let that leak out as a "size".
  const auto sz = std::ftell(f);
  // Close before checking so that the handle is released on every path.
  std::fclose(f);
  CHECK(seek_result == 0) << VV(path);
  CHECK(sz >= 0) << VV(path);
  return sz;
}

namespace {

int HandleGlobError(const char *epath, int eerrno) {
Expand Down
Loading

0 comments on commit 564b831

Please sign in to comment.