-
Notifications
You must be signed in to change notification settings - Fork 118
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[runtime] support MNN inference engine (#310)
* [runtime] support MNN format * [runtime] support MNN format * [runtime] support MNN format * [runtime] add runtime code * [runtime] add MNN_BUILD_MINI * [runtime] add README.md * [runtime] refine code * [runtime] fix lint * [runtime] test rtf * [runtime] fix README.md * [runtime] fix typo * [runtime] fix typo * [runtime] fix typo
- Loading branch information
Showing
15 changed files
with
434 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Fetches the MNN inference engine with FetchContent and adds it to the build
# when the MNN option is enabled.
if(MNN) | ||
# The archive is pinned to a single MNN commit and verified with SHA256.
set(MNN_URL "https://github.com/alibaba/MNN/archive/976d1d7c0f916ea8a7acc3d31352789590f00b18.zip") | ||
set(URL_HASH "SHA256=7fcef0933992658e8725bdc1df2daff1410c8577c9c1ce838fd5d6c8c01d1ec1") | ||
|
||
FetchContent_Declare(mnn | ||
URL ${MNN_URL} | ||
URL_HASH ${URL_HASH} | ||
) | ||
|
||
# MNN build options. They are forced into the cache here, before the
# add_subdirectory() call below pulls in MNN's own CMakeLists.
set(MNN_BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) | ||
set(MNN_BUILD_TOOLS OFF CACHE BOOL "" FORCE) | ||
set(MNN_SUPPORT_DEPRECATED_OP OFF CACHE BOOL "" FORCE) | ||
set(MNN_SEP_BUILD ON CACHE BOOL "" FORCE) | ||
# MNN_BUILD_MINI takes its value from the caller-provided MINI_LIBS variable.
set(MNN_BUILD_MINI ${MINI_LIBS} CACHE BOOL "" FORCE) # mini version | ||
set(MNN_JNI OFF CACHE BOOL "" FORCE) | ||
set(MNN_USE_CPP11 ON CACHE BOOL "" FORCE) | ||
set(MNN_SUPPORT_BF16 OFF CACHE BOOL "" FORCE) | ||
set(MNN_BUILD_OPENCV OFF CACHE BOOL "" FORCE) | ||
set(MNN_LOW_MEMORY OFF CACHE BOOL "" FORCE) | ||
|
||
# Populate only once; mnn_POPULATED is set by FetchContent_GetProperties.
FetchContent_GetProperties(mnn) | ||
if(NOT mnn_POPULATED) | ||
message(STATUS "Downloading mnn from ${MNN_URL}") | ||
FetchContent_Populate(mnn) | ||
endif() | ||
|
||
message(STATUS "mnn is downloaded to ${mnn_SOURCE_DIR}") | ||
message(STATUS "mnn's binary dir is ${mnn_BINARY_DIR}") | ||
# Build MNN as part of this project and expose its headers and library dir.
add_subdirectory(${mnn_SOURCE_DIR} ${mnn_BINARY_DIR}) | ||
include_directories(${mnn_SOURCE_DIR}/include) | ||
link_directories(${mnn_BINARY_DIR}) | ||
# Adds the USE_MNN compile definition for targets in this directory scope.
add_definitions(-DUSE_MNN) | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,23 @@ | ||
set(speaker_srcs | ||
speaker_engine.cc) | ||
|
||
if(NOT ONNX AND NOT BPU) | ||
message(FATAL_ERROR "Please build with ONNX or BPU!") | ||
if(NOT ONNX AND NOT MNN) | ||
message(FATAL_ERROR "Please build with ONNX or MNN!") | ||
endif() | ||
if(ONNX) | ||
list(APPEND speaker_srcs onnx_speaker_model.cc) | ||
endif() | ||
if(MNN) | ||
list(APPEND speaker_srcs mnn_speaker_model.cc) | ||
endif() | ||
|
||
add_library(speaker STATIC ${speaker_srcs}) | ||
target_link_libraries(speaker PUBLIC frontend) | ||
|
||
if(ONNX) | ||
target_link_libraries(speaker PUBLIC onnxruntime) | ||
endif() | ||
if(MNN) | ||
target_link_libraries(speaker PUBLIC MNN) | ||
endif() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// Copyright (c) 2024 Chengdong Liang ([email protected]) | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifdef USE_MNN | ||
|
||
#include <vector> | ||
|
||
#include "glog/logging.h" | ||
#include "speaker/mnn_speaker_model.h" | ||
#include "utils/utils.h" | ||
|
||
namespace wespeaker { | ||
|
||
// Loads the MNN model stored at `model_path` and creates a CPU session that
// runs with `num_threads` threads, low precision and high power settings.
//
// On failure (model cannot be loaded or session cannot be created) an error is
// logged and `speaker_session_` is left as nullptr.
MnnSpeakerModel::MnnSpeakerModel(const std::string& model_path,
                                 int num_threads) {
  // 1. Load the model.
  speaker_interpreter_ = std::shared_ptr<MNN::Interpreter>(
      MNN::Interpreter::createFromFile(model_path.c_str()));
  if (!speaker_interpreter_) {
    // createFromFile() reports failure by returning a null pointer; without
    // this guard the createSession() call below would dereference it.
    LOG(ERROR) << "[MNN] Load model failed: " << model_path;
    return;
  }

  // 2. Describe how the session should be scheduled.
  MNN::BackendConfig backend_config;
  backend_config.precision = MNN::BackendConfig::Precision_Low;
  backend_config.power = MNN::BackendConfig::Power_High;

  MNN::ScheduleConfig config;
  config.type = MNN_FORWARD_CPU;
  config.numThread = num_threads;
  // `backend_config` is a local; it stays alive across the createSession()
  // call that reads it through this pointer.
  config.backendConfig = &backend_config;

  // 3. Create the session.
  speaker_session_ = speaker_interpreter_->createSession(config);
  if (!speaker_session_) {
    LOG(ERROR) << "[MNN] Create session failed!";
    return;
  }
}
|
||
// Releases the model buffer and the session, but only when a session was
// actually created by the constructor.
MnnSpeakerModel::~MnnSpeakerModel() {
  if (speaker_session_ == nullptr) {
    return;  // nothing was created, nothing to release
  }
  speaker_interpreter_->releaseModel();
  speaker_interpreter_->releaseSession(speaker_session_);
}
|
||
void MnnSpeakerModel::ExtractEmbedding( | ||
const std::vector<std::vector<float>>& feats, std::vector<float>* embed) { | ||
unsigned int num_frames = feats.size(); | ||
unsigned int feat_dim = feats[0].size(); | ||
|
||
// 1. input tensor | ||
auto input_tensor = | ||
speaker_interpreter_->getSessionInput(speaker_session_, nullptr); | ||
|
||
auto shape = input_tensor->shape(); | ||
CHECK_EQ(shape.size(), 3); | ||
if (shape[0] == -1 || shape[1] == -1 || shape[2] == -1) { | ||
VLOG(2) << "dynamic shape."; | ||
std::vector<int> input_dims = {1, static_cast<int>(num_frames), | ||
static_cast<int>(feat_dim)}; | ||
speaker_interpreter_->resizeTensor(input_tensor, input_dims); | ||
speaker_interpreter_->resizeSession(speaker_session_); | ||
} else { | ||
if (shape[0] != 1 || shape[1] != num_frames || shape[2] != feat_dim) { | ||
LOG(ERROR) << "shape error!"; | ||
return; | ||
} | ||
} | ||
|
||
std::shared_ptr<MNN::Tensor> nchw_tensor( | ||
new MNN::Tensor(input_tensor, MNN::Tensor::CAFFE)); // NCHW | ||
for (size_t i = 0; i < num_frames; ++i) { | ||
for (size_t j = 0; j < feat_dim; ++j) { | ||
nchw_tensor->host<float>()[i * feat_dim + j] = feats[i][j]; | ||
} | ||
} | ||
input_tensor->copyFromHostTensor(nchw_tensor.get()); | ||
|
||
// 2. run session | ||
speaker_interpreter_->runSession(speaker_session_); | ||
|
||
// 3. output | ||
auto output = speaker_interpreter_->getSessionOutput(speaker_session_, NULL); | ||
std::shared_ptr<MNN::Tensor> output_tensor( | ||
new MNN::Tensor(output, MNN::Tensor::CAFFE)); | ||
output->copyToHostTensor(output_tensor.get()); | ||
embed->reserve(output_tensor->elementSize()); | ||
for (int i = 0; i < output_tensor->elementSize(); ++i) { | ||
embed->push_back(output->host<float>()[i]); | ||
} | ||
} | ||
|
||
} // namespace wespeaker | ||
|
||
#endif // USE_MNN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Copyright (c) 2024 Chengdong Liang ([email protected]) | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef SPEAKER_MNN_SPEAKER_MODEL_H_ | ||
#define SPEAKER_MNN_SPEAKER_MODEL_H_ | ||
|
||
#ifdef USE_MNN | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "MNN/Interpreter.hpp" | ||
#include "MNN/MNNDefine.h" | ||
#include "MNN/Tensor.hpp" | ||
#include "speaker/speaker_model.h" | ||
|
||
namespace wespeaker { | ||
|
||
class MnnSpeakerModel : public SpeakerModel { | ||
public: | ||
explicit MnnSpeakerModel(const std::string& model_path, int num_threads); | ||
|
||
void ExtractEmbedding(const std::vector<std::vector<float>>& feats, | ||
std::vector<float>* embed) override; | ||
~MnnSpeakerModel(); | ||
|
||
private: | ||
// session | ||
std::shared_ptr<MNN::Interpreter> speaker_interpreter_; | ||
MNN::Session* speaker_session_ = nullptr; | ||
}; | ||
|
||
} // namespace wespeaker | ||
|
||
#endif // USE_MNN | ||
#endif // SPEAKER_MNN_SPEAKER_MODEL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
fc_base/ | ||
build* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
cmake_minimum_required(VERSION 3.14)
project(wespeaker VERSION 0.1)

# Build switches: MNN selects the MNN backend, MINI_LIBS is forwarded to
# MNN_BUILD_MINI by the mnn dependency module.
set(MNN ON CACHE BOOL "whether to build with MNN")
option(MINI_LIBS "whether to build minimum libraries with MNN" OFF)

set(CMAKE_VERBOSE_MAKEFILE OFF)

# Third-party sources fetched with FetchContent are stored in the fc_base
# directory under this source directory, with download output shown.
include(FetchContent)
set(FETCHCONTENT_QUIET OFF)
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})

# Lets include(<name>) below find modules in the local cmake/ directory.
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC")

# Include all dependency
if(MNN)
  include(mnn)
endif()
include(glog)
include(gflags)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# build all libraries
add_subdirectory(utils)
add_subdirectory(frontend)
add_subdirectory(speaker)
add_subdirectory(bin)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# MNN backend on WeSpeaker | ||
|
||
* Step 1. Export your experiment model to MNN | ||
|
||
First, export your experiment model to ONNX by [export_onnx.py](../../wespeaker/bin/export_onnx.py). | ||
|
||
``` sh | ||
# 1. dynamic shape | ||
python wespeaker/bin/export_onnx.py \ | ||
--config config.yaml \ | ||
--checkpoint model.pt \ | ||
--output_model model.onnx | ||
# When it finishes, you can find `model.onnx`. | ||
# 2. static shape | ||
# python wespeaker/bin/export_onnx.py \ | ||
# --config config.yaml \ | ||
# --checkpoint model.pt \ | ||
# --output_model model.onnx \ | ||
# --num_frames 198 | ||
``` | ||
|
||
Second, export ONNX to MNN by [export_mnn.py](../../wespeaker/bin/export_mnn.py). | ||
|
||
``` sh | ||
# 1. dynamic shape | ||
python wespeaker/bin/export_mnn.py \ | ||
--onnx_model model.onnx \ | ||
--output_model model.mnn | ||
# When it finishes, you can find `model.mnn`. | ||
# 2. static shape | ||
# python wespeaker/bin/export_mnn.py \ | ||
# --onnx_model model.onnx \ | ||
# --output_model model.mnn \ | ||
# --num_frames 198 | ||
``` | ||
|
||
* Step 2. Build. The build requires cmake 3.14 or above, and gcc/g++ 5.4 or above. | ||
|
||
``` sh | ||
mkdir build && cd build | ||
# 1. normal | ||
cmake .. | ||
# 2. minimum libs | ||
# cmake -DMINI_LIBS=ON .. | ||
cmake --build . | ||
``` | ||
|
||
* Step 3. Testing. | ||
|
||
1. The RTF (real-time factor) is printed to the console, and the embeddings are written to the text file. | ||
``` sh | ||
export GLOG_logtostderr=1 | ||
export GLOG_v=2 | ||
./build/bin/extract_emb_main \ | ||
--wav_scp wav.scp \ | ||
--result embedding.txt \ | ||
--speaker_model_path model.mnn \ | ||
--embedding_size 256 \ | ||
--samples_per_chunk 80000 # 5s | ||
``` | ||
|
||
> NOTE: samples_per_chunk: samples of one chunk. samples_per_chunk = sample_rate * duration | ||
> | ||
> If samples_per_chunk = -1, the embedding is computed over the whole utterance; | ||
> otherwise embeddings are computed chunk by chunk and then averaged over all chunks. | ||
2. Calculate the similarity of two utterances. | ||
```sh | ||
export GLOG_logtostderr=1 | ||
export GLOG_v=2 | ||
./build/bin/asv_main \ | ||
--enroll_wav wav1_path \ | ||
--test_wav wav2_path \ | ||
--threshold 0.5 \ | ||
--speaker_model_path model.mnn \ | ||
--embedding_size 256 | ||
``` | ||
|
||
## Benchmark | ||
|
||
1. RTF | ||
> num_threads = 1 | ||
> | ||
> samples_per_chunk = 32000 | ||
> | ||
> Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz | ||
| Model | Params | FLOPs | engine | RTF | | ||
| :------------------ | :------ | :------- | :------------ | :------- | | ||
| ResNet-34 | 6.63 M | 4.55 G | onnxruntime | 0.1377 | | ||
| ResNet-34 | 6.63 M | 4.55 G | mnn | 0.1333 | | ||
| ResNet-34 | 6.63 M | 4.55 G | mnn mini_libs | 0.05262 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/cmake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/frontend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/speaker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/utils |
Oops, something went wrong.