Enable float32 model with FP16 precision for QNN HTP backend (#19863)
### Description
Enable float32 model with FP16 precision for QNN HTP backend
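For context, a minimal sketch of how an application could opt in to this option through the C++ API. The backend library name and model path are placeholder assumptions, not part of this change:

#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // QNN EP options are passed as string key/value pairs.
  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";      // HTP backend library (placeholder)
  qnn_options["enable_htp_fp16_precision"] = "1";  // run the float32 model with fp16 precision

  session_options.AppendExecutionProvider("QNN", qnn_options);
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}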
HectorSVC authored Mar 13, 2024
1 parent 6579f74 commit 60ad6c6
Showing 7 changed files with 71 additions and 2 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3619,6 +3619,10 @@ struct OrtApi {
* - "73"
* - "75"
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
"enable_htp_fp16_precision": Only used for float32 model.
Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision.
- "0": Default. With fp32 precision.
- "1": With fp16 precision.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
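The same option flows through the C API function these keys document, SessionOptionsAppendExecutionProvider. A sketch with error handling elided; the backend library name is an assumption:

const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
OrtSessionOptions* session_options = NULL;
ort_api->CreateSessionOptions(&session_options);

// Keys and values mirror the provider options documented in this header.
const char* keys[] = {"backend_path", "enable_htp_fp16_precision"};
const char* values[] = {"QnnHtp.dll", "1"};
ort_api->SessionOptionsAppendExecutionProvider(session_options, "QNN", keys, values, 2);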
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -300,6 +300,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
}
}

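// Parse the optional "enable_htp_fp16_precision" provider option ("0" = fp32, default; "1" = fp16).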
static const std::string QNN_HTP_FP16_MODE = "enable_htp_fp16_precision";
auto htp_fp16_mode_pos = provider_options_map.find(QNN_HTP_FP16_MODE);
if (htp_fp16_mode_pos != provider_options_map.end()) {
if ("1" == htp_fp16_mode_pos->second) {
enable_HTP_FP16_precision_ = true;
} else if ("0" == htp_fp16_mode_pos->second) {
enable_HTP_FP16_precision_ = false;
} else {
LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_fp16_precision: " << enable_HTP_FP16_precision_ << " only 0 or 1 allowed. Set to 0.";
}
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}

qnn_backend_manager_ = std::make_unique<qnn::QnnBackendManager>(
std::move(backend_path),
profiling_level,
@@ -637,6 +650,16 @@ void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder<QnnGraph_C
graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm;
}

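// Request fp16 math for this graph: the HTP-specific precision option is expressed as a
// custom config wrapped in a generic QnnGraph_Config_t entry.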
if (enable_HTP_FP16_precision_) {
QnnHtpGraph_CustomConfig_t& htp_graph_precision_config = configs_builder.PushCustomConfig();
htp_graph_precision_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_PRECISION;
htp_graph_precision_config.precision = QNN_PRECISION_FLOAT16;

QnnGraph_Config_t& graph_precision_config = configs_builder.PushConfig();
graph_precision_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM;
graph_precision_config.customConfig = &htp_graph_precision_config;
}
}
}

1 change: 1 addition & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -84,6 +84,7 @@ class QNNExecutionProvider : public IExecutionProvider {
uint32_t device_id_ = 0;
qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault;
uint32_t default_rpc_control_latency_ = 0;
bool enable_HTP_FP16_precision_ = false;

class PerThreadContext final {
public:
13 changes: 12 additions & 1 deletion onnxruntime/test/onnx/main.cc
@@ -64,6 +64,8 @@ void usage() {
"\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
"\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
"\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
"\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
"\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
"\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>' \n\n"
"\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n"
"\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
@@ -525,11 +527,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_arch. select from: " + str);
}
} else if (key == "enable_htp_fp16_precision") {
std::unordered_set<std::string> supported_options = {"0", "1"};
if (supported_options.find(value) == supported_options.end()) {
std::ostringstream str_stream;
std::copy(supported_options.begin(), supported_options.end(),
std::ostream_iterator<std::string>(str_stream, ","));
std::string str = str_stream.str();
ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path',
'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority',
'soc_model', 'htp_arch', 'device_id'])");
'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
}

qnn_options[key] = value;
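A plausible invocation of the test runner with the new key, following the -e/-i syntax documented above (the backend path and test-data directory are assumptions):

onnx_test_runner -e qnn -i "backend_path|QnnHtp.dll enable_htp_fp16_precision|1" /path/to/test_data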
2 changes: 2 additions & 0 deletions onnxruntime/test/perftest/command_args_parser.cc
@@ -94,6 +94,8 @@ namespace perftest {
"\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n"
"\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n"
"\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n"
"\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n"
"\t Otherwise, it will be fp32 precision. Only works for float32 model. Defaults to '0' (with FP32 precision.). \n"
"\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n"
"\n"
"\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n"
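Likewise for the perf tool, under the same assumptions about paths:

onnxruntime_perf_test -e qnn -i "backend_path|QnnHtp.dll enable_htp_fp16_precision|1" model.onnx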
11 changes: 10 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
@@ -382,11 +382,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
std::string str = str_stream.str();
ORT_THROW("Wrong value for htp_arch. select from: " + str);
}
} else if (key == "enable_htp_fp16_precision") {
std::unordered_set<std::string> supported_options = {"0", "1"};
if (supported_options.find(value) == supported_options.end()) {
std::ostringstream str_stream;
std::copy(supported_options.begin(), supported_options.end(),
std::ostream_iterator<std::string>(str_stream, ","));
std::string str = str_stream.str();
ORT_THROW("Wrong value for enable_htp_fp16_precision. select from: " + str);
}
} else {
ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path',
'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model',
'htp_arch', 'device_id'])");
'htp_arch', 'device_id', 'enable_htp_fp16_precision'])");
}

qnn_options[key] = value;
19 changes: 19 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_basic_test.cc
@@ -815,6 +815,25 @@ TEST_F(QnnHTPBackendTests, DISABLED_CastAddHTPAccuracyTest) {
ExpectedEPNodeAssignment::All);
}

// Test float32 model with FP16 precision
TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) {
ProviderOptions provider_options;
#if defined(_WIN32)
provider_options["backend_path"] = "QnnHtp.dll";
#else
provider_options["backend_path"] = "libQnnHtp.so";
#endif
provider_options["enable_htp_fp16_precision"] = "1";

auto input_defs = {TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f),
TestInputDef<float>({1, 2, 2, 2}, false, -10.0f, 10.0f)};
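// Relaxed absolute tolerance (0.008f below) to absorb the rounding error fp16 execution
// introduces relative to the float32 reference.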
RunQnnModelTest(BuildOpTestCase<float>("Add", input_defs, {}, {}, kOnnxDomain),
provider_options,
13,
ExpectedEPNodeAssignment::All,
0.008f);
}

#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
#endif // !defined(ORT_MINIMAL_BUILD)

