Skip to content

Commit

Permalink
[Migraphx EP] Static int8 QDQ support (#17931)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Adding static int8 quantization support for MIGraphX Execution Provider

- Allows for parsing in calibration tables generated by Onnxruntime or
TensorRT's toolsets
- Add proper environment variables into the MIGraphX EP
- Update python API to include updating execution provider flags -> was
missing on python side
- Hook into MIGraphX's int8 quantitation and optimization of models

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Required so that we can get onnxruntime to pass in models while
leveraging the existing tooling for int8 static QDQ quantization.

First step in a series of PRs which will add further static quantization
on the operator level as MIGraphX releases further support.

These changes drew heavily from the tensorRT EP should allow for similar
functionality for GPU based (versus CPU) quantization of models before
an inference is performed.

---------

Co-authored-by: Ted Themistokleous <[email protected]>
Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
3 people authored Nov 9, 2023
1 parent 55c19d6 commit 8d50313
Show file tree
Hide file tree
Showing 12 changed files with 671 additions and 61 deletions.
8 changes: 5 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,11 @@ typedef struct OrtTensorRTProviderOptions {
* \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
*/
typedef struct OrtMIGraphXProviderOptions {
int device_id; // hip device id.
int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true
int device_id; // hip device id.
int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/migraphx/migraphx_call.cc
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include <unistd.h>
#include <string.h>
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>
#include "migraphx_call.h"
#include "core/common/common.h"
#include "core/common/status.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_call.h"

namespace onnxruntime {

Expand Down
281 changes: 251 additions & 30 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

27 changes: 21 additions & 6 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@

#pragma once

#include <miopen/miopen.h>
#include <rocblas/rocblas.h>

#include "core/framework/arena_extend_strategy.h"
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_inc.h"

#include <map>
#include "migraphx_inc.h"
#include <unordered_map>
// TODO: find a better way to share this
// #include "core/providers/cuda/rocm_stream_handle.h"
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>

namespace onnxruntime {

namespace migraphx_env_vars {
static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE";
static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
}; // namespace migraphx_env_vars

// Information to construct kernel function state.
Expand All @@ -35,6 +41,9 @@ struct MIGraphXFuncState {
OrtMutex* mgx_mu_ptr = nullptr;
bool no_input_shape = false;
bool fp16_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
std::unordered_map<std::string, float> dynamic_range_map;
bool dump_model_ops = false;
};

Expand Down Expand Up @@ -69,6 +78,12 @@ class MIGraphXExecutionProvider : public IExecutionProvider {

private:
bool fp16_enable_ = false;
bool int8_enable_ = false;
std::string int8_calibration_cache_name_;
bool int8_calibration_cache_available_ = false;
bool int8_use_native_migraphx_calibration_table_ = false;
std::string calibration_cache_path_;
std::unordered_map<std::string, float> dynamic_range_map;
bool dump_model_ops_ = false;
int device_id_;
migraphx::target t_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ namespace migraphx {
namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kFp16Enable = "trt_fp16_enable";
constexpr const char* kInt8Enable = "trt_int8_enable";
constexpr const char* kInt8Enable = "migx_int8_enable";
constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name";
constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table";

} // namespace provider_option_names
} // namespace migraphx

Expand Down Expand Up @@ -45,15 +48,17 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}};
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
};
return options;
}

ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) {
const ProviderOptions options{
{migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}};
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
};
return options;
}
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <limits>
#include <string>

#include "core/framework/ortdevice.h"
#include "core/framework/provider_options.h"
Expand All @@ -16,6 +17,8 @@ struct MIGraphXExecutionProviderInfo {
int device_id{0};
bool fp16_enable{false};
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,20 @@
// Licensed under the MIT License

#pragma once

#include <fstream>
#include <unordered_map>
#include <string>
#include <iostream>
#include <filesystem>
#include <memory>
#include "flatbuffers/idl.h"
#include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/execution_provider.h"
#include "core/common/path_string.h"

namespace fs = std::filesystem;

namespace onnxruntime {

Expand Down Expand Up @@ -101,7 +113,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector
return true;
}

bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, std::vector<NodeIndex>& input_nodes) {
bool canEvalNodeArgument(const GraphViewer& graph,
const Node* node,
std::vector<std::size_t> indices,
std::vector<NodeIndex>& input_nodes) {
input_nodes.clear();
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) {
Expand Down Expand Up @@ -137,4 +152,102 @@ bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector
return true;
}

float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) {
int s = (input >> 31) & 0x01;
int e = ((input & 0x7f800000) >> 23) - 127;
int p = -1;
double m = 0.0;
for (int i = 0; i < 23; ++i) {
m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--);
}
return static_cast<float>((s ? -1 : 1) * pow(2.0, e) * (m + 1.0));
}

/*
* Read calibration table for INT8 quantization
* Two kind of calibration tables are supported,
* 1. ORT generated calibration table
* The table is pre-serialized by flatbuffers.
* Each entry in the table is a key-value pair,
* key: tensor name, value: maximum absolute value in floating point
* For example,
* data_0 2.008338
* ...
* 2. Native TensorRT generated calibration table
* Data format is defined by TensorRT as,
* tensor name : scale in 32-bit single precision IEEE754 format
* For example,
* TRT-7103-EntropyCalibration2
* data_0: 4000889d
* ...
*
* Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models
*
*/
bool ReadDynamicRange(const std::string file_name,
const bool is_calibration_table,
std::unordered_map<std::string,
float>& dynamic_range_map) {
std::ifstream infile(file_name, std::ios::binary | std::ios::in);
if (!infile) {
return false;
}

if (is_calibration_table) {
// Native TensorRT generated calibration table
std::string line;
char delim = ':';
if (std::getline(infile, line)) {
std::istringstream first_line(line);
std::string version;
std::getline(first_line, version, delim);
std::size_t found = version.find("TRT-");
if (found != std::string::npos) {
while (std::getline(infile, line)) {
std::istringstream in_line(line);
std::string str;
std::getline(in_line, str, delim);
std::string tensor_name = str;
std::getline(in_line, str, delim);
uint32_t scale_int = std::strtoul(str.c_str(), nullptr, 16);
float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int);
float dynamic_range = scale_float * 127.0f;
dynamic_range_map[tensor_name] = dynamic_range;
}
} else {
throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name);
}
}
} else {
// ORT generated calibration table
infile.seekg(0, std::ios::end);
size_t length = infile.tellg();
infile.seekg(0, std::ios::beg);
std::unique_ptr<char[]> data{new char[length]};
infile.read(reinterpret_cast<char*>(data.get()), length);
infile.close();
auto flat_table = flatbuffers::GetRoot<CalTableFlatBuffers::TrtTable>(reinterpret_cast<char*>(data.get()));
auto flat_dict = flat_table->dict();
for (size_t i = 0, end = flat_dict->size(); i < end; ++i) {
flatbuffers::uoffset_t idx = static_cast<flatbuffers::uoffset_t>(i);
dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str());
}
}
return true;
}

/*
* Get cache by name
*
*/
std::string GetCachePath(const std::string& root, const std::string& name) {
if (root.empty()) {
return name;
} else {
fs::path path = root;
path.append(name);
return path.string();
}
}

} // namespace onnxruntime
32 changes: 27 additions & 5 deletions onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License
#include <atomic>

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/migraphx/migraphx_provider_factory.h"
Expand All @@ -8,7 +9,6 @@
#include "hip_allocator.h"
#include "gpu_data_transfer.h"
#include "core/framework/provider_options.h"
#include <atomic>

#include "core/session/onnxruntime_c_api.h"

Expand Down Expand Up @@ -48,15 +48,37 @@ struct MIGraphX_Provider : Provider {
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
if (options.migraphx_int8_calibration_table_name != nullptr) {
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name;
}
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
return std::make_shared<MIGraphXProviderFactory>(info);
}

void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override {
auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options);
auto& trt_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
trt_options.device_id = internal_options.device_id;
trt_options.migraphx_fp16_enable = internal_options.fp16_enable;
trt_options.migraphx_int8_enable = internal_options.int8_enable;
auto& migx_options = *reinterpret_cast<OrtMIGraphXProviderOptions*>(provider_options);
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;

char* dest = nullptr;
auto str_size = internal_options.int8_calibration_table_name.size();
if (str_size == 0) {
migx_options.migraphx_int8_calibration_table_name = nullptr;
} else {
dest = new char[str_size + 1];
#ifdef _MSC_VER
strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size);
#else
strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size);
#endif
dest[str_size] = '\0';
migx_options.migraphx_int8_calibration_table_name = (const char*)dest;
}

migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down
Loading

0 comments on commit 8d50313

Please sign in to comment.