-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[runtime] support MNN inference engine #310
Merged
Changes from 10 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
8f166ff
[runtime] support MNN format
cdliang11 2c6154f
[runtime] support MNN format
cdliang11 3b0d85b
[runtime] support MNN format
cdliang11 b712b2e
Merge remote-tracking branch 'origin/master' into mnn
cdliang11 a857dd5
[runtime] add runtime code
cdliang11 7bfa55d
[runtime] add MNN_BUILD_MINI
cdliang11 2f556ad
[runtime] add README.md
cdliang11 1fc5b19
[runtime] refine code
cdliang11 02112bf
[runtime] fix lint
cdliang11 66706f0
[runtime] test rtf
cdliang11 1150afb
[runtime] fix README.md
cdliang11 c4e608c
[runtime] fix typo
cdliang11 054f644
[runtime] fix typo
cdliang11 017916f
[runtime] fix typo
cdliang11 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
if(MNN) | ||
set(MNN_URL "https://github.com/alibaba/MNN/archive/976d1d7c0f916ea8a7acc3d31352789590f00b18.zip") | ||
set(URL_HASH "SHA256=7fcef0933992658e8725bdc1df2daff1410c8577c9c1ce838fd5d6c8c01d1ec1") | ||
|
||
FetchContent_Declare(mnn | ||
URL ${MNN_URL} | ||
URL_HASH ${URL_HASH} | ||
) | ||
|
||
# 设置编译宏 | ||
set(MNN_BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) | ||
set(MNN_BUILD_TOOLS OFF CACHE BOOL "" FORCE) | ||
set(MNN_SUPPORT_DEPRECATED_OP OFF CACHE BOOL "" FORCE) | ||
set(MNN_SEP_BUILD ON CACHE BOOL "" FORCE) | ||
set(MNN_BUILD_MINI ${MINI_LIBS} CACHE BOOL "" FORCE) # mini version | ||
set(MNN_JNI OFF CACHE BOOL "" FORCE) | ||
set(MNN_USE_CPP11 ON CACHE BOOL "" FORCE) | ||
set(MNN_SUPPORT_BF16 OFF CACHE BOOL "" FORCE) | ||
set(MNN_BUILD_OPENCV OFF CACHE BOOL "" FORCE) | ||
set(MNN_LOW_MEMORY OFF CACHE BOOL "" FORCE) | ||
|
||
FetchContent_GetProperties(mnn) | ||
if(NOT mnn_POPULATED) | ||
message(STATUS "Downloading mnn from ${MNN_URL}") | ||
FetchContent_Populate(mnn) | ||
endif() | ||
|
||
message(STATUS "mnn is downloaded to ${mnn_SOURCE_DIR}") | ||
message(STATUS "mnn's binary dir is ${mnn_BINARY_DIR}") | ||
add_subdirectory(${mnn_SOURCE_DIR} ${mnn_BINARY_DIR}) | ||
include_directories(${mnn_SOURCE_DIR}/include) | ||
link_directories(${mnn_BINARY_DIR}) | ||
add_definitions(-DUSE_MNN) | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,23 @@ | ||
set(speaker_srcs | ||
speaker_engine.cc) | ||
|
||
if(NOT ONNX AND NOT BPU) | ||
message(FATAL_ERROR "Please build with ONNX or BPU!") | ||
if(NOT ONNX AND NOT MNN) | ||
message(FATAL_ERROR "Please build with ONNX or MNN!") | ||
endif() | ||
if(ONNX) | ||
list(APPEND speaker_srcs onnx_speaker_model.cc) | ||
endif() | ||
if(MNN) | ||
list(APPEND speaker_srcs mnn_speaker_model.cc) | ||
endif() | ||
|
||
add_library(speaker STATIC ${speaker_srcs}) | ||
target_link_libraries(speaker PUBLIC frontend) | ||
|
||
if(ONNX) | ||
target_link_libraries(speaker PUBLIC onnxruntime) | ||
endif() | ||
if(MNN) | ||
target_link_libraries(speaker PUBLIC MNN) | ||
endif() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
// Copyright (c) 2024 Chengdong Liang ([email protected]) | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifdef USE_MNN | ||
|
||
#include <vector> | ||
|
||
#include "glog/logging.h" | ||
#include "speaker/mnn_speaker_model.h" | ||
#include "utils/utils.h" | ||
|
||
namespace wespeaker { | ||
|
||
MnnSpeakerModel::MnnSpeakerModel(const std::string& model_path, | ||
int num_threads) { | ||
// 1. Load sessions | ||
speaker_interpreter_ = std::shared_ptr<MNN::Interpreter>( | ||
MNN::Interpreter::createFromFile(model_path.c_str())); | ||
|
||
MNN::ScheduleConfig config; | ||
config.type = MNN_FORWARD_CPU; | ||
config.numThread = num_threads; | ||
MNN::BackendConfig backend_config; | ||
backend_config.precision = MNN::BackendConfig::Precision_Low; | ||
backend_config.power = MNN::BackendConfig::Power_High; | ||
config.backendConfig = &backend_config; | ||
|
||
speaker_session_ = speaker_interpreter_->createSession(config); | ||
if (!speaker_session_) { | ||
LOG(ERROR) << "[MNN] Create session failed!"; | ||
return; | ||
} | ||
} | ||
|
||
MnnSpeakerModel::~MnnSpeakerModel() { | ||
if (speaker_session_) { | ||
speaker_interpreter_->releaseModel(); | ||
speaker_interpreter_->releaseSession(speaker_session_); | ||
} | ||
} | ||
|
||
void MnnSpeakerModel::ExtractEmbedding( | ||
const std::vector<std::vector<float>>& feats, std::vector<float>* embed) { | ||
unsigned int num_frames = feats.size(); | ||
unsigned int feat_dim = feats[0].size(); | ||
|
||
// 1. input tensor | ||
auto input_tensor = | ||
speaker_interpreter_->getSessionInput(speaker_session_, nullptr); | ||
|
||
auto shape = input_tensor->shape(); | ||
CHECK_EQ(shape.size(), 3); | ||
if (shape[0] == -1 || shape[1] == -1 || shape[2] == -1) { | ||
VLOG(2) << "dynamic shape."; | ||
std::vector<int> input_dims = {1, static_cast<int>(num_frames), | ||
static_cast<int>(feat_dim)}; | ||
speaker_interpreter_->resizeTensor(input_tensor, input_dims); | ||
speaker_interpreter_->resizeSession(speaker_session_); | ||
} else { | ||
if (shape[0] != 1 || shape[1] != num_frames || shape[2] != feat_dim) { | ||
LOG(ERROR) << "shape error!"; | ||
return; | ||
} | ||
} | ||
|
||
std::shared_ptr<MNN::Tensor> nchw_tensor( | ||
new MNN::Tensor(input_tensor, MNN::Tensor::CAFFE)); // NCHW | ||
for (size_t i = 0; i < num_frames; ++i) { | ||
for (size_t j = 0; j < feat_dim; ++j) { | ||
nchw_tensor->host<float>()[i * feat_dim + j] = feats[i][j]; | ||
} | ||
} | ||
input_tensor->copyFromHostTensor(nchw_tensor.get()); | ||
|
||
// 2. run session | ||
speaker_interpreter_->runSession(speaker_session_); | ||
|
||
// 3. output | ||
auto output = speaker_interpreter_->getSessionOutput(speaker_session_, NULL); | ||
std::shared_ptr<MNN::Tensor> output_tensor( | ||
new MNN::Tensor(output, MNN::Tensor::CAFFE)); | ||
output->copyToHostTensor(output_tensor.get()); | ||
embed->reserve(output_tensor->elementSize()); | ||
for (int i = 0; i < output_tensor->elementSize(); ++i) { | ||
embed->push_back(output->host<float>()[i]); | ||
} | ||
} | ||
|
||
} // namespace wespeaker | ||
|
||
#endif // USE_MNN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Copyright (c) 2024 Chengdong Liang ([email protected]) | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef SPEAKER_MNN_SPEAKER_MODEL_H_ | ||
#define SPEAKER_MNN_SPEAKER_MODEL_H_ | ||
|
||
#ifdef USE_MNN | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "MNN/Interpreter.hpp" | ||
#include "MNN/MNNDefine.h" | ||
#include "MNN/Tensor.hpp" | ||
#include "speaker/speaker_model.h" | ||
|
||
namespace wespeaker { | ||
|
||
class MnnSpeakerModel : public SpeakerModel { | ||
public: | ||
explicit MnnSpeakerModel(const std::string& model_path, int num_threads); | ||
|
||
void ExtractEmbedding(const std::vector<std::vector<float>>& feats, | ||
std::vector<float>* embed) override; | ||
~MnnSpeakerModel(); | ||
|
||
private: | ||
// session | ||
std::shared_ptr<MNN::Interpreter> speaker_interpreter_; | ||
MNN::Session* speaker_session_ = nullptr; | ||
}; | ||
|
||
} // namespace wespeaker | ||
|
||
#endif // USE_MNN | ||
#endif // SPEAKER_MNN_SPEAKER_MODEL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
fc_base/ | ||
build* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
cmake_minimum_required(VERSION 3.14) | ||
project(wespeaker VERSION 0.1) | ||
|
||
set(MNN ON CACHE BOOL "whether to build with MNN") | ||
option(MINI_LIBS "whether to build minimum libraies with MNN" OFF) | ||
|
||
set(CMAKE_VERBOSE_MAKEFILE OFF) | ||
|
||
include(FetchContent) | ||
set(FETCHCONTENT_QUIET OFF) | ||
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") | ||
set(FETCHCONTENT_BASE_DIR ${fc_base}) | ||
|
||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake) | ||
|
||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC") | ||
|
||
# Include all dependency | ||
if(MNN) | ||
include(mnn) | ||
endif() | ||
include(glog) | ||
include(gflags) | ||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | ||
|
||
# build all libraries | ||
add_subdirectory(utils) | ||
add_subdirectory(frontend) | ||
add_subdirectory(speaker) | ||
add_subdirectory(bin) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# MNN backend on WeSpeaker | ||
|
||
* Step 1. Export your experiment model to MNN | ||
|
||
First, export your experiment model to ONNX by [export_onnx.py](../../wespeaker/bin/export_onnx.py). | ||
|
||
``` sh | ||
# 1. dynamic shape | ||
python wespeaker/bin/export_onnx.py \ | ||
--config config.yaml \ | ||
--checkpoint model.pt \ | ||
--output_model model.onnx | ||
# When it finishes, you can find `model.mnn`. | ||
# 2. static shape | ||
# python wespeaker/bin/export_onnx.py \ | ||
# --config config.yaml \ | ||
# --checkpoint model.pt \ | ||
# --output_model model.onnx \ | ||
# --num_frames 198 | ||
``` | ||
|
||
Second, export ONNX to MNN by [export_mnn.py](../../wespeaker/bin/export_mnn.py). | ||
|
||
``` sh | ||
# 1. dynamic shape | ||
python wespeaker/bin/export_mnn.py \ | ||
--onnx_model model.onnx \ | ||
--output_model model.mnn | ||
# When it finishes, you can find `model.mnn`. | ||
# 2. static shape | ||
# python wespeaker/bin/export_mnn.py \ | ||
# --onnx_model model.onnx \ | ||
# --output_model model.mnn \ | ||
# --num_frames 198 | ||
``` | ||
|
||
* Step 2. Build. The build requires cmake 3.14 or above, and gcc/g++ 5.4 or above. | ||
|
||
``` sh | ||
mkdir build && cd build | ||
# 1. normal | ||
cmake .. | ||
# 2. minimum libs | ||
# cmake -DMINI_LIBS=ON .. | ||
cmake --build . | ||
``` | ||
|
||
* Step 3. Testing. | ||
|
||
1. the RTF(real time factor) is shown in the console, and embedding will be written to the txt file. | ||
``` sh | ||
export GLOG_logtostderr=1 | ||
export GLOG_v=2 | ||
wav_scp=your_test_wav_scp | ||
mnn_dir=your_model_dir | ||
embed_out=your_embedding_txt | ||
./build/bin/extract_emb_main \ | ||
--wav_scp $wav_scp \ | ||
--result $embed_out \ | ||
--speaker_model_path $mnn_dir/final.mnn \ | ||
--embedding_size 256 \ | ||
--samples_per_chunk 80000 # 5s | ||
``` | ||
|
||
> NOTE: samples_per_chunk: samples of one chunk. samples_per_chunk = sample_rate * duration | ||
> | ||
> If samples_per_chunk = -1, compute the embedding of whole sentence; | ||
> else compute embedding with chunk by chunk, and then average embeddings of chunk. | ||
|
||
2. Calculate the similarity of two speech. | ||
```sh | ||
export GLOG_logtostderr=1 | ||
export GLOG_v=2 | ||
mnn_dir=your_model_dir | ||
./build/bin/asv_main \ | ||
--enroll_wav wav1_path \ | ||
--test_wav wav2_path \ | ||
--threshold 0.5 \ | ||
--speaker_model_path $onnx_dir/final.onnx \ | ||
--embedding_size 256 | ||
``` | ||
|
||
## Benchmark | ||
|
||
1. RTF | ||
> num_threads = 1 | ||
> | ||
> samples_per_chunk = 3200 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 32000 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修正 |
||
> | ||
> Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz | ||
|
||
| Model | Params | FLOPs | engine | RTF | | ||
| :------------------ | :------ | :------- | :------------ | :------- | | ||
| ResNet-34 | 6.63 M | 4.55 G | onnxruntime | 0.1377 | | ||
| ResNet-34 | 6.63 M | 4.55 G | mnn | 0.1333 | | ||
| ResNet-34 | 6.63 M | 4.55 G | mnn mini_libs | 0.05262 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/cmake |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/frontend |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/speaker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../core/utils |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model.onnx ??
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修正