
[runtime] support MNN inference engine #310

Merged · 14 commits · Apr 23, 2024
34 changes: 34 additions & 0 deletions runtime/core/cmake/mnn.cmake
@@ -0,0 +1,34 @@
if(MNN)
set(MNN_URL "https://github.com/alibaba/MNN/archive/976d1d7c0f916ea8a7acc3d31352789590f00b18.zip")
set(URL_HASH "SHA256=7fcef0933992658e8725bdc1df2daff1410c8577c9c1ce838fd5d6c8c01d1ec1")

FetchContent_Declare(mnn
URL ${MNN_URL}
URL_HASH ${URL_HASH}
)

# Configure MNN build options
set(MNN_BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(MNN_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
set(MNN_SUPPORT_DEPRECATED_OP OFF CACHE BOOL "" FORCE)
set(MNN_SEP_BUILD ON CACHE BOOL "" FORCE)
set(MNN_BUILD_MINI ${MINI_LIBS} CACHE BOOL "" FORCE) # mini version
set(MNN_JNI OFF CACHE BOOL "" FORCE)
set(MNN_USE_CPP11 ON CACHE BOOL "" FORCE)
set(MNN_SUPPORT_BF16 OFF CACHE BOOL "" FORCE)
set(MNN_BUILD_OPENCV OFF CACHE BOOL "" FORCE)
set(MNN_LOW_MEMORY OFF CACHE BOOL "" FORCE)

FetchContent_GetProperties(mnn)
if(NOT mnn_POPULATED)
message(STATUS "Downloading mnn from ${MNN_URL}")
FetchContent_Populate(mnn)
endif()

message(STATUS "mnn is downloaded to ${mnn_SOURCE_DIR}")
message(STATUS "mnn's binary dir is ${mnn_BINARY_DIR}")
add_subdirectory(${mnn_SOURCE_DIR} ${mnn_BINARY_DIR})
include_directories(${mnn_SOURCE_DIR}/include)
link_directories(${mnn_BINARY_DIR})
add_definitions(-DUSE_MNN)
endif()
10 changes: 8 additions & 2 deletions runtime/core/speaker/CMakeLists.txt
@@ -1,17 +1,23 @@
set(speaker_srcs
speaker_engine.cc)

-if(NOT ONNX AND NOT BPU)
-  message(FATAL_ERROR "Please build with ONNX or BPU!")
+if(NOT ONNX AND NOT MNN)
+  message(FATAL_ERROR "Please build with ONNX or MNN!")
endif()
if(ONNX)
list(APPEND speaker_srcs onnx_speaker_model.cc)
endif()
if(MNN)
list(APPEND speaker_srcs mnn_speaker_model.cc)
endif()

add_library(speaker STATIC ${speaker_srcs})
target_link_libraries(speaker PUBLIC frontend)

if(ONNX)
target_link_libraries(speaker PUBLIC onnxruntime)
endif()
if(MNN)
target_link_libraries(speaker PUBLIC MNN)
endif()

102 changes: 102 additions & 0 deletions runtime/core/speaker/mnn_speaker_model.cc
@@ -0,0 +1,102 @@
// Copyright (c) 2024 Chengdong Liang ([email protected])
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef USE_MNN

#include <vector>

#include "glog/logging.h"
#include "speaker/mnn_speaker_model.h"
#include "utils/utils.h"

namespace wespeaker {

MnnSpeakerModel::MnnSpeakerModel(const std::string& model_path,
int num_threads) {
// 1. Load sessions
speaker_interpreter_ = std::shared_ptr<MNN::Interpreter>(
MNN::Interpreter::createFromFile(model_path.c_str()));

MNN::ScheduleConfig config;
config.type = MNN_FORWARD_CPU;
config.numThread = num_threads;
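// Note: Precision_Low lets MNN select reduced-precision (e.g. fp16) kernels
// where the CPU supports them; Power_High favors speed over power draw.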
MNN::BackendConfig backend_config;
backend_config.precision = MNN::BackendConfig::Precision_Low;
backend_config.power = MNN::BackendConfig::Power_High;
config.backendConfig = &backend_config;

speaker_session_ = speaker_interpreter_->createSession(config);
if (!speaker_session_) {
LOG(ERROR) << "[MNN] Create session failed!";
return;
}
}

MnnSpeakerModel::~MnnSpeakerModel() {
if (speaker_session_) {
speaker_interpreter_->releaseModel();
speaker_interpreter_->releaseSession(speaker_session_);
}
}

void MnnSpeakerModel::ExtractEmbedding(
const std::vector<std::vector<float>>& feats, std::vector<float>* embed) {
unsigned int num_frames = feats.size();
unsigned int feat_dim = feats[0].size();

// 1. input tensor
auto input_tensor =
speaker_interpreter_->getSessionInput(speaker_session_, nullptr);

auto shape = input_tensor->shape();
CHECK_EQ(shape.size(), 3);
if (shape[0] == -1 || shape[1] == -1 || shape[2] == -1) {
VLOG(2) << "dynamic shape.";
std::vector<int> input_dims = {1, static_cast<int>(num_frames),
static_cast<int>(feat_dim)};
speaker_interpreter_->resizeTensor(input_tensor, input_dims);
speaker_interpreter_->resizeSession(speaker_session_);
} else {
    if (shape[0] != 1 || shape[1] != static_cast<int>(num_frames) ||
        shape[2] != static_cast<int>(feat_dim)) {
      LOG(ERROR) << "[MNN] Input shape mismatch!";
return;
}
}

std::shared_ptr<MNN::Tensor> nchw_tensor(
new MNN::Tensor(input_tensor, MNN::Tensor::CAFFE)); // NCHW
for (size_t i = 0; i < num_frames; ++i) {
for (size_t j = 0; j < feat_dim; ++j) {
nchw_tensor->host<float>()[i * feat_dim + j] = feats[i][j];
}
}
input_tensor->copyFromHostTensor(nchw_tensor.get());

// 2. run session
speaker_interpreter_->runSession(speaker_session_);

// 3. output
  auto output = speaker_interpreter_->getSessionOutput(speaker_session_, nullptr);
std::shared_ptr<MNN::Tensor> output_tensor(
new MNN::Tensor(output, MNN::Tensor::CAFFE));
output->copyToHostTensor(output_tensor.get());
  embed->reserve(output_tensor->elementSize());
  for (int i = 0; i < output_tensor->elementSize(); ++i) {
    // Read from the host-side copy, not the device tensor.
    embed->push_back(output_tensor->host<float>()[i]);
  }
}

} // namespace wespeaker

#endif // USE_MNN
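
For reference, a minimal usage sketch of the class above, not part of this PR: it assumes a model already converted by `export_mnn.py` and fabricates all-zero features, and the 200-frame, 80-dim feature shape is illustrative only.

```cpp
// Hypothetical usage sketch (not part of this PR); assumes the project is
// built with -DUSE_MNN and that final.mnn was produced by export_mnn.py.
#include <iostream>
#include <vector>

#include "speaker/mnn_speaker_model.h"

int main() {
  // Illustrative shape only: 200 frames of 80-dim fbank features.
  const int num_frames = 200;
  const int feat_dim = 80;
  std::vector<std::vector<float>> feats(
      num_frames, std::vector<float>(feat_dim, 0.0f));  // dummy features

  wespeaker::MnnSpeakerModel model("final.mnn", /*num_threads=*/1);
  std::vector<float> embed;
  model.ExtractEmbedding(feats, &embed);
  std::cout << "embedding size: " << embed.size() << std::endl;
  return 0;
}
```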
49 changes: 49 additions & 0 deletions runtime/core/speaker/mnn_speaker_model.h
@@ -0,0 +1,49 @@
// Copyright (c) 2024 Chengdong Liang ([email protected])
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SPEAKER_MNN_SPEAKER_MODEL_H_
#define SPEAKER_MNN_SPEAKER_MODEL_H_

#ifdef USE_MNN

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "MNN/Interpreter.hpp"
#include "MNN/MNNDefine.h"
#include "MNN/Tensor.hpp"
#include "speaker/speaker_model.h"

namespace wespeaker {

class MnnSpeakerModel : public SpeakerModel {
public:
explicit MnnSpeakerModel(const std::string& model_path, int num_threads);

void ExtractEmbedding(const std::vector<std::vector<float>>& feats,
std::vector<float>* embed) override;
~MnnSpeakerModel();

private:
// session
std::shared_ptr<MNN::Interpreter> speaker_interpreter_;
MNN::Session* speaker_session_ = nullptr;
};

} // namespace wespeaker

#endif // USE_MNN
#endif // SPEAKER_MNN_SPEAKER_MODEL_H_
5 changes: 5 additions & 0 deletions runtime/core/speaker/speaker_engine.cc
@@ -21,6 +21,9 @@
#ifdef USE_ONNX
#include "speaker/onnx_speaker_model.h"
#endif
#ifdef USE_MNN
#include "speaker/mnn_speaker_model.h"
#endif

namespace wespeaker {

@@ -48,6 +51,8 @@ SpeakerEngine::SpeakerEngine(const std::string& model_path, const int feat_dim,
OnnxSpeakerModel::SetGpuDeviceId(0);
#endif
model_ = std::make_shared<OnnxSpeakerModel>(model_path);
#elif USE_MNN
model_ = std::make_shared<MnnSpeakerModel>(model_path, kNumGemmThreads);
#elif USE_BPU
model_ = std::make_shared<BpuSpeakerModel>(model_path);
#endif
2 changes: 2 additions & 0 deletions runtime/mnn/.gitignore
@@ -0,0 +1,2 @@
fc_base/
build*
30 changes: 30 additions & 0 deletions runtime/mnn/CMakeLists.txt
@@ -0,0 +1,30 @@
cmake_minimum_required(VERSION 3.14)
project(wespeaker VERSION 0.1)

set(MNN ON CACHE BOOL "whether to build with MNN")
option(MINI_LIBS "whether to build minimum libraries with MNN" OFF)

set(CMAKE_VERBOSE_MAKEFILE OFF)

include(FetchContent)
set(FETCHCONTENT_QUIET OFF)
get_filename_component(fc_base "fc_base" REALPATH BASE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_base})

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -pthread -fPIC")

# Include all dependencies
if(MNN)
include(mnn)
endif()
include(glog)
include(gflags)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

# build all libraries
add_subdirectory(utils)
add_subdirectory(frontend)
add_subdirectory(speaker)
add_subdirectory(bin)
96 changes: 96 additions & 0 deletions runtime/mnn/README.md
@@ -0,0 +1,96 @@
# MNN backend for WeSpeaker

* Step 1. Export your trained model to MNN

First, export the model to ONNX with [export_onnx.py](../../wespeaker/bin/export_onnx.py).

``` sh
# 1. dynamic shape
python wespeaker/bin/export_onnx.py \
--config config.yaml \
--checkpoint model.pt \
--output_model model.onnx
# When it finishes, you can find `model.onnx`.
# 2. static shape
# python wespeaker/bin/export_onnx.py \
# --config config.yaml \
# --checkpoint model.pt \
# --output_model model.onnx \
# --num_frames 198
```

Second, convert the ONNX model to MNN with [export_mnn.py](../../wespeaker/bin/export_mnn.py).

``` sh
# 1. dynamic shape
python wespeaker/bin/export_mnn.py \
--onnx_model model.onnx \
--output_model model.mnn
# When it finishes, you can find `model.mnn`.
# 2. static shape
# python wespeaker/bin/export_mnn.py \
# --onnx_model model.onnx \
# --output_model model.mnn \
# --num_frames 198
```

* Step 2. Build. Building requires CMake 3.14 or above and gcc/g++ 5.4 or above.

``` sh
mkdir build && cd build
# 1. normal
cmake ..
# 2. minimum libs
# cmake -DMINI_LIBS=ON ..
cmake --build .
```

* Step 3. Testing.

1. The RTF (real-time factor) is printed to the console, and embeddings are written to the result text file.
``` sh
export GLOG_logtostderr=1
export GLOG_v=2
wav_scp=your_test_wav_scp
mnn_dir=your_model_dir
embed_out=your_embedding_txt
./build/bin/extract_emb_main \
--wav_scp $wav_scp \
--result $embed_out \
--speaker_model_path $mnn_dir/final.mnn \
--embedding_size 256 \
--samples_per_chunk 80000 # 5s
```

> NOTE: samples_per_chunk is the number of samples per chunk: samples_per_chunk = sample_rate * duration (e.g. 16000 * 5 = 80000 for 5 s of 16 kHz audio).
>
> If samples_per_chunk = -1, the embedding is computed over the whole utterance; otherwise embeddings are computed chunk by chunk and then averaged.

2. Compute the similarity between two utterances.
```sh
export GLOG_logtostderr=1
export GLOG_v=2
mnn_dir=your_model_dir
./build/bin/asv_main \
--enroll_wav wav1_path \
--test_wav wav2_path \
--threshold 0.5 \
--speaker_model_path $mnn_dir/final.mnn \
--embedding_size 256
```
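
Here `asv_main` extracts an embedding for each wav and accepts the pair as the same speaker when their similarity score exceeds `--threshold`.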

## Benchmark

1. RTF
> num_threads = 1
>
> samples_per_chunk = 32000
>
> Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz

| Model | Params | FLOPs | engine | RTF |
| :------------------ | :------ | :------- | :------------ | :------- |
| ResNet-34 | 6.63 M | 4.55 G | onnxruntime | 0.1377 |
| ResNet-34 | 6.63 M | 4.55 G | mnn | 0.1333 |
| ResNet-34 | 6.63 M | 4.55 G | mnn mini_libs | 0.05262 |
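
RTF here is inference time divided by audio duration (RTF = t_process / t_audio), so an RTF of 0.1333 means one second of audio is processed in roughly 0.13 s on a single thread; lower is faster.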
1 change: 1 addition & 0 deletions runtime/mnn/bin
1 change: 1 addition & 0 deletions runtime/mnn/cmake
1 change: 1 addition & 0 deletions runtime/mnn/frontend
1 change: 1 addition & 0 deletions runtime/mnn/speaker
1 change: 1 addition & 0 deletions runtime/mnn/utils