Enable QNN weight sharing (#21077)
### Description
Enable QNN weight sharing across graphs compiled into a single QNN context.
Add a tool (`onnxruntime_qnn_ctx_gen`) that generates QNN context cache models with weight sharing enabled.
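For illustration, a minimal sketch of the workflow this enables (model paths, the backend library name, and the use of two sessions are assumptions, not part of this change). Creating both sessions with the options below compiles each graph into one shared, weight-sharing QNN context and dumps the EPContext cache models:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_ctx_gen_sketch");
  Ort::SessionOptions so;
  so.AddConfigEntry("ep.context_enable", "1");     // dump EPContext cache models
  so.AddConfigEntry("ep.share_ep_contexts", "1");  // new: share EP resources across sessions
  so.AppendExecutionProvider("QNN", {{"backend_path", "QnnHtp.dll"},
                                     {"enable_htp_weight_sharing", "1"}});  // new provider option
  // Session creation compiles the graphs; by default model1_ctx.onnx and
  // model2_ctx.onnx are written next to the input models.
  Ort::Session s1(env, ORT_TSTR("model1.onnx"), so);
  Ort::Session s2(env, ORT_TSTR("model2.onnx"), so);
  return 0;
}
```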
HectorSVC authored Sep 4, 2024
1 parent 9031112 commit 190588b
Showing 18 changed files with 1,078 additions and 50 deletions.
31 changes: 31 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -1262,6 +1262,37 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()
endif()
endif()


if(onnxruntime_USE_QNN)
#qnn ctx generator
set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen)
set(onnxruntime_qnn_ctx_gen_src_patterns
"${onnxruntime_qnn_ctx_gen_src_dir}/*.cc"
"${onnxruntime_qnn_ctx_gen_src_dir}/*.h")

file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS
${onnxruntime_qnn_ctx_gen_src_patterns}
)
onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src})
target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT}
${eigen_INCLUDE_DIRS} ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir}
${CMAKE_CURRENT_BINARY_DIR})
if (WIN32)
target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings})
if (NOT DEFINED SYS_PATH_LIB)
set(SYS_PATH_LIB shlwapi)
endif()
endif()

if(WIN32)
target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32)
endif()
target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})

set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest")
endif()

# shared lib
if (onnxruntime_BUILD_SHARED_LIB)
onnxruntime_add_static_library(onnxruntime_mocked_allocator ${TEST_SRC_DIR}/util/test_allocator.cc)
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3654,6 +3654,9 @@ struct OrtApi {
*   Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
*     - "0": Default. With fp32 precision.
*     - "1": With fp16 precision.
*   "enable_htp_weight_sharing": Enable QNN weight sharing feature while compiling multiple graphs into one QNN context.
*     - "0": Default. Disabled.
*     - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
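A hedged sketch of setting the provider option documented in this hunk through the C API (error handling elided; the backend library name is an assumption):

```cpp
#include <onnxruntime_c_api.h>

void ConfigureQnnWeightSharing(OrtSessionOptions* session_options) {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  const char* keys[] = {"backend_path", "enable_htp_weight_sharing"};
  const char* values[] = {"QnnHtp.dll", "1"};
  // Returned OrtStatus* is ignored here for brevity; real code should check it.
  ort->SessionOptionsAppendExecutionProvider(session_options, "QNN", keys, values, 2);
}
```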
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -269,6 +269,9 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed
// in case user need to merge/connect multiple EPContext nodes in one model
static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix";

// Share EP related resources across EPs
static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts";

// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
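A short sketch of opting in via the new named constant (assumes onnxruntime_session_options_config_keys.h is on the include path):

```cpp
#include <onnxruntime_cxx_api.h>
#include <onnxruntime_session_options_config_keys.h>

void ShareEpContexts(Ort::SessionOptions& so) {
  // Equivalent to so.AddConfigEntry("ep.share_ep_contexts", "1");
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");
}
```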
33 changes: 16 additions & 17 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -13,7 +13,8 @@ namespace onnxruntime {
namespace qnn {

bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) {
// It's an Onnx model with Qnn context cache binary if it has a node with EPContext type and the source is QNN or QNNExecutionProvider.
// It's an Onnx model with Qnn context cache binary if it has a node with EPContext type
// and the source is QNN or QNNExecutionProvider.
for (const auto& node : graph_viewer.Nodes()) {
if (EPCONTEXT_OP == node.OpType()) {
NodeAttrHelper node_helper(node);
@@ -44,19 +45,14 @@ bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGr
}

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::vector<int>& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
std::vector<int>& main_context_pos) {
for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) {
// Only EPContext nodes are filtered in
// There is only one EPContext node in one filtered graph -- this is guaranteed by GetCapability
const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph);
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
const auto& ep_context_node = graph_viewer.Nodes().begin();
ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node.");
qnn_models.emplace(ep_context_node->Name(),
std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager));
NodeAttrHelper node_helper(*ep_context_node);
int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast<int64_t>(0));
if (1 == is_main_context) {
Expand Down Expand Up @@ -91,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
const logging::Logger& logger,
QnnModelLookupTable& qnn_models) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -100,6 +97,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
logger,
qnn_models);
}

@@ -149,22 +147,23 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
logger,
qnn_models);
}

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger) {
for (const auto& ep_context_node : graph_viewer.Nodes()) {
Status status = GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models);
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
logger, qnn_models);

// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage();
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage());
}

return Status::OK();
@@ -197,7 +196,7 @@ Status CreateEPContextNodes(Model* model,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const QnnModelLookupTable& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -19,6 +19,7 @@ namespace qnn {

class QnnModel;
class QnnBackendManager;
using QnnModelLookupTable = std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>;

static const std::string EPCONTEXT_OP = "EPContext";
static const std::string MAIN_CONTEXT = "main_context";
@@ -33,10 +34,7 @@ bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
bool IsFusedGraphHasCtxNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs);

Status GetMainContextNode(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
QnnBackendManager* qnn_backend_manager,
const logging::Logger& logger,
std::vector<int>& main_context_pos,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
std::vector<int>& main_context_pos);

Status CreateNodeArgs(const std::vector<std::string>& names,
const std::unordered_map<std::string, OnnxTensorInfo>& tensor_info_table,
@@ -51,12 +49,13 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
const logging::Logger& logger,
QnnModelLookupTable& qnn_models);

Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
QnnModelLookupTable& qnn_models,
const logging::Logger& logger);

Status CreateEPContextNodes(Model* model,
35 changes: 23 additions & 12 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -13,6 +13,7 @@
// #include "GPU/QnnGpuCommon.h"
#include "DSP/QnnDspCommon.h"
#include "HTP/QnnHtpCommon.h"
#include "HTP/QnnHtpContext.h"
#include <gsl/gsl>
#include "core/framework/endian_utils.h"
#include "core/common/logging/capture.h"
@@ -208,6 +209,7 @@ Status QnnBackendManager::LoadQnnSystemLib() {
#else
std::string system_lib_file = "libQnnSystem.so";
#endif // #ifdef _WIN32
LOGS_DEFAULT(INFO) << "Loading QnnSystem lib";
std::filesystem::path lib_file_path(backend_path_.c_str());
std::string sys_file_path(lib_file_path.remove_filename().string() + system_lib_file);
QnnSystemInterface_t* system_interface_provider{nullptr};
@@ -520,9 +522,18 @@ Status QnnBackendManager::CreateContext() {
return Status::OK();
}

QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
QnnHtpContext_CustomConfig_t customConfig;
customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
context_config_weight_sharing.customConfig = &customConfig;

QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
const QnnContext_Config_t* context_configs[] = {&context_priority_config,
&context_config_weight_sharing,
nullptr};

Qnn_ContextHandle_t context = nullptr;
Qnn_ErrorHandle_t result = qnn_interface_.contextCreate(backend_handle_,
@@ -597,7 +608,8 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6

Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models) {
const logging::Logger& logger,
QnnModelLookupTable& qnn_models) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@@ -631,7 +643,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
}

ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size();
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;

ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
"Invalid function pointer for contextCreateFromBinary.");
@@ -653,15 +665,14 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
if (1 == graph_count) {
// in case the EPContext node is generated from script
// the graph name from the context binary may not match the EPContext node name
auto qnn_model_pos = qnn_models.find(node_name);
ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), node_name, " does not match any EPContext node names.");
ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context));
auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context));
qnn_models.emplace(node_name, std::move(qnn_model));
} else {
for (uint32_t i = 0; i < graph_count; ++i) {
std::string graph_name(graphs_info[i].graphInfoV1.graphName);
auto qnn_model_pos = qnn_models.find(graph_name);
ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names.");
ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context));
auto qnn_model = std::make_unique<qnn::QnnModel>(logger, this);
ORT_RETURN_IF_ERROR(qnn_model->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context));
qnn_models.emplace(graphs_info[i].graphInfoV1.graphName, std::move(qnn_model));
}
}

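On the consumption side of this load path, context cache models that share one QNN context can be loaded like ordinary ONNX models; a hedged sketch (file and library names are placeholders), with ep.share_ep_contexts set so both sessions reuse the same QNN resources:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_ctx_load_sketch");
  Ort::SessionOptions so;
  so.AddConfigEntry("ep.share_ep_contexts", "1");  // reuse QNN resources across sessions
  so.AppendExecutionProvider("QNN", {{"backend_path", "QnnHtp.dll"}});
  // Each *_ctx.onnx embeds (or references) the QNN context binary generated earlier.
  Ort::Session session1(env, ORT_TSTR("model1_ctx.onnx"), so);
  Ort::Session session2(env, ORT_TSTR("model2_ctx.onnx"), so);
  return 0;
}
```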
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -39,7 +39,8 @@ class QnnBackendManager {
std::string&& qnn_saver_path,
uint32_t device_id,
QnnHtpDevice_Arch_t htp_arch,
uint32_t soc_model)
uint32_t soc_model,
bool enable_htp_weight_sharing)
: backend_path_(backend_path),
profiling_level_etw_(profiling_level_etw),
profiling_level_(profiling_level),
Expand All @@ -48,7 +49,8 @@ class QnnBackendManager {
qnn_saver_path_(qnn_saver_path),
device_id_(device_id),
htp_arch_(htp_arch),
soc_model_(soc_model) {
soc_model_(soc_model),
enable_htp_weight_sharing_(enable_htp_weight_sharing) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);

@@ -89,6 +91,7 @@

Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
const logging::Logger& logger,
std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);

Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
@@ -262,6 +265,7 @@ class QnnBackendManager {
uint32_t device_id_ = 0;
QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE;
uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN;
bool enable_htp_weight_sharing_ = false;
};

} // namespace qnn
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -149,7 +149,7 @@ Status QnnModel::FinalizeGraphs() {
qnn_backend_manager_->GetQnnProfileHandle(),
nullptr);
if (QNN_GRAPH_NO_ERROR != status) {
LOGS(logger_, ERROR) << "Failed to finalize QNN graph.";
LOGS(logger_, ERROR) << "Failed to finalize QNN graph. Error code: " << status;
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to finalize QNN graph.");
}

2 changes: 1 addition & 1 deletion onnxruntime/core/providers/qnn/builder/qnn_model.h
@@ -102,7 +102,7 @@ class QnnModel {
return outputs_info_;
}

const std::string& Name() { return graph_info_->Name(); }
const std::string& Name() const { return graph_info_->Name(); }

private:
const NodeUnit& GetNodeUnit(const Node* node,