Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[runtime] support multi-threaded extracting embedding #262

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 67 additions & 45 deletions runtime/core/bin/extract_emb_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,77 +19,99 @@

#include "frontend/wav.h"
#include "speaker/speaker_engine.h"
#include "utils/thread_pool.h"
#include "utils/timer.h"
#include "utils/utils.h"

DEFINE_string(wav_list, "", "input wav scp");
DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_string(wav_path, "", "input wav path");
DEFINE_string(result, "", "output embedding file");

DEFINE_string(speaker_model_path, "", "path of speaker model");
DEFINE_int32(fbank_dim, 80, "fbank feature dimension");
DEFINE_int32(sample_rate, 16000, "sample rate");
DEFINE_int32(embedding_size, 256, "embedding size");
DEFINE_int32(samples_per_chunk, 32000, "samples of one chunk");
DEFINE_int32(thread_num, 1, "num of extract_emb thread");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
std::ofstream g_result;
std::mutex g_result_mutex;
int g_total_waves_dur = 0;
int g_total_extract_time = 0;

void extract_emb(std::pair<std::string, std::string> wav) {
// init model
LOG(INFO) << "Init model ...";
auto speaker_engine = std::make_shared<wespeaker::SpeakerEngine>(
FLAGS_speaker_model_path, FLAGS_fbank_dim, FLAGS_sample_rate,
FLAGS_embedding_size, FLAGS_samples_per_chunk);
int embedding_size = speaker_engine->EmbeddingSize();
LOG(INFO) << "embedding size: " << embedding_size;
// read wav.scp
// [utt, wav_path]
wenet::WavReader wav_reader(wav.second);
CHECK_EQ(wav_reader.sample_rate(), 16000);
int16_t* data = const_cast<int16_t*>(wav_reader.data());
int samples = wav_reader.num_sample();
// NOTE(cdliang): memory allocation
std::vector<float> embs(FLAGS_embedding_size, 0);

int wave_dur = static_cast<int>(static_cast<float>(samples) /
wav_reader.sample_rate() * 1000);
int extract_time = 0;
wenet::Timer timer;
speaker_engine->ExtractEmbedding(data, samples, &embs);
extract_time = timer.Elapsed();
LOG(INFO) << "process: " << wav.first
<< " RTF: " << static_cast<float>(extract_time) / wave_dur;
g_result_mutex.lock();
std::ostream& buffer = FLAGS_result.empty() ? std::cout : g_result;
buffer << wav.first;
for (size_t i = 0; i < embs.size(); i++) {
buffer << " " << embs[i];
}
buffer << std::endl;
g_total_waves_dur += wave_dur;
g_total_extract_time += extract_time;
g_result_mutex.unlock();
}

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

if (FLAGS_wav_scp.empty() && FLAGS_wav_path.empty()) {
LOG(FATAL) << "wav_scp and wav_path should not be empty at the same time";
}

std::vector<std::pair<std::string, std::string>> waves;
std::ifstream wav_scp(FLAGS_wav_list);
std::string line;
while (getline(wav_scp, line)) {
std::vector<std::string> strs;
wespeaker::SplitString(line, &strs);
CHECK_EQ(strs.size(), 2);
waves.emplace_back(make_pair(strs[0], strs[1]));
if (!FLAGS_wav_path.empty()) {
waves.emplace_back(make_pair("test", FLAGS_wav_path));
} else {
std::ifstream wav_scp(FLAGS_wav_scp);
std::string line;
while (getline(wav_scp, line)) {
std::vector<std::string> strs;
wespeaker::SplitString(line, &strs);
CHECK_EQ(strs.size(), 2);
waves.emplace_back(make_pair(strs[0], strs[1]));
}
if (waves.empty()) {
LOG(FATAL) << "Please provide non-empty wav scp.";
}
}

std::ofstream result;
if (!FLAGS_result.empty()) {
result.open(FLAGS_result, std::ios::out);
g_result.open(FLAGS_result, std::ios::out);
}
std::ostream& buffer = FLAGS_result.empty() ? std::cout : result;

int total_waves_dur = 0;
int total_extract_time = 0;
for (auto& wav : waves) {
auto data_reader = wenet::ReadAudioFile(wav.second);
CHECK_EQ(data_reader->sample_rate(), 16000);
int16_t* data = const_cast<int16_t*>(data_reader->data());
int samples = data_reader->num_sample();
// NOTE(cdliang): memory allocation
std::vector<float> embs(embedding_size, 0);
buffer << wav.first;

int wave_dur = static_cast<int>(static_cast<float>(samples) /
data_reader->sample_rate() * 1000);
int extract_time = 0;
wenet::Timer timer;
speaker_engine->ExtractEmbedding(data, samples, &embs);
extract_time = timer.Elapsed();
for (size_t i = 0; i < embs.size(); i++) {
buffer << " " << embs[i];
{
ThreadPool pool(std::min(FLAGS_thread_num, static_cast<int>(waves.size())));
for (auto& wav : waves) {
pool.enqueue(extract_emb, wav);
}
buffer << std::endl;
LOG(INFO) << "process: " << wav.first
<< " RTF: " << static_cast<float>(extract_time) / wave_dur;
total_waves_dur += wave_dur;
total_extract_time += extract_time;
}
result.close();
LOG(INFO) << "Total: process " << total_waves_dur << "ms audio taken "
<< total_extract_time << "ms.";
LOG(INFO) << "RTF: "
<< static_cast<float>(total_extract_time) / total_waves_dur;
LOG(INFO) << "Total: process " << g_total_waves_dur << "ms audio taken "
<< g_total_extract_time << "ms.";
LOG(INFO) << "RTF: " << std::setprecision(4)
<< static_cast<float>(g_total_extract_time) / g_total_waves_dur;
return 0;
}
113 changes: 113 additions & 0 deletions runtime/core/utils/thread_pool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2012 Jakob Progsch, Václav Zeman

// This software is provided 'as-is', without any express or implied
// warranty. In no event will the authors be held liable for any damages
// arising from the use of this software.

// Permission is granted to anyone to use this software for any purpose,
// including commercial applications, and to alter it and redistribute it
// freely, subject to the following restrictions:

// 1. The origin of this software must not be misrepresented; you must not
// claim that you wrote the original software. If you use this software
// in a product, an acknowledgment in the product documentation would be
// appreciated but is not required.

// 2. Altered source versions must be plainly marked as such, and must not be
// misrepresented as being the original software.

// 3. This notice may not be removed or altered from any source
// distribution.

#ifndef UTILS_THREAD_POOL_H_
#define UTILS_THREAD_POOL_H_

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

class ThreadPool {
public:
explicit ThreadPool(size_t);
template <class F, class... Args>
auto enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();

private:
// need to keep track of threads so we can join them
std::vector<std::thread> workers;
// the task queue
std::queue<std::function<void()> > tasks;

// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads) : stop(false) {
for (size_t i = 0; i < threads; ++i)
workers.emplace_back([this] {
for (;;) {
std::function<void()> task;

{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(
lock, [this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty()) return;
task = std::move(this->tasks.front());
this->tasks.pop();
}

task();
}
});
}

// add new work item to the pool
template <class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type> {
using return_type = typename std::result_of<F(Args...)>::type;

auto task = std::make_shared<std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...));

std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);

// don't allow enqueueing after stopping the pool
if (stop) {
throw std::runtime_error("enqueue on stopped ThreadPool");
}

tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}

// the destructor joins all threads
inline ThreadPool::~ThreadPool() {
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : workers) {
worker.join();
}
}

#endif // UTILS_THREAD_POOL_H_
2 changes: 1 addition & 1 deletion runtime/horizonbpu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ export GLOG_v=2
wav_scp=your_test_wav_scp
embed_out=your_embedding_txt
./build/bin/extract_emb_main \
--wav_list $wav_scp \
--wav_scp $wav_scp \
--result $embed_out \
--speaker_model_path speaker.bin \
--embedding_size 256 \
Expand Down
2 changes: 1 addition & 1 deletion runtime/onnxruntime/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ wav_scp=your_test_wav_scp
onnx_dir=your_model_dir
embed_out=your_embedding_txt
./build/bin/extract_emb_main \
--wav_list $wav_scp \
--wav_scp $wav_scp \
--result $embed_out \
--speaker_model_path $onnx_dir/final.onnx \
--embedding_size 256 \
Expand Down
Loading