diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c7d4a236bcf89..cddad732104ed 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -599,9 +599,11 @@ typedef struct OrtTensorRTProviderOptions { * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX */ typedef struct OrtMIGraphXProviderOptions { - int device_id; // hip device id. - int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true - int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true + int device_id; // hip device id. + int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // MIGraphX INT8 precision. Default 0 = false, nonzero = true + int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true + const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index cd947420b7615..5248ac2f39214 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,14 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/shared_library/provider_api.h" #include #include #include #include -#include "migraphx_call.h" #include "core/common/common.h" #include "core/common/status.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/migraphx/migraphx_call.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index d2538544db60e..d1b3f19100942 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,5 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License +#include +#include +#include +#include +#include #include "core/providers/shared_library/provider_api.h" #define ORT_API_MANUAL_INIT @@ -12,10 +17,6 @@ #include "gpu_data_transfer.h" #include "migraphx_inc.h" -#include -#include -#include - // TODO: find a better way to share this #include "core/providers/rocm/rocm_stream_handle.h" @@ -113,6 +114,45 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } + // whether int8 is enabled + const std::string int8_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8Enable); + if (!int8_enable_env.empty()) { + int8_enable_ = (std::stoi(int8_enable_env) == 0 ? 
false : true); + } + + if (int8_enable_) { + const std::string int8_calibration_cache_name_env = + onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName); + if (!int8_calibration_cache_name_env.empty()) { + int8_calibration_cache_name_ = int8_calibration_cache_name_env; + } + + const std::string cache_path = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kCachePath); + if (!cache_path.empty()) { + calibration_cache_path_ = cache_path; + } + + const std::string int8_use_native_migraphx_calibration_table_env = + onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable); + if (!int8_use_native_migraphx_calibration_table_env.empty()) { + int8_use_native_migraphx_calibration_table_ = + (std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true); + } + } + + if (int8_enable_) { + int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(calibration_cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_migraphx_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + // dump unsupported ops const std::string dump_model_ops_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::dumpModelOps); if (!dump_model_ops_env.empty()) { @@ -124,6 +164,15 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); + + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " + << "device_id: " << device_id_ + << ", migraphx_fp16_enable: " << fp16_enable_ + << ", migraphx_int8_enable: " << 
int8_enable_ + << ", dump_model_ops: " << dump_model_ops_ + << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ + << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ + << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_; } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { @@ -467,7 +516,8 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return false; } -void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, const logging::Logger& logger) { +void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, + const logging::Logger& logger) { // Then check whether a subgraph should fallback to CPU // 1. Check whether a subgraph contains a RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; @@ -642,7 +692,8 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st fused_inputs.erase(iter); erased.insert(output); } else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -660,7 +711,8 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st } // Only when output is neither in input list nor erased list, add the output to output list else if (erased.find(output) == erased.end()) { - if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + if (std::find(graph_output_names.begin(), + graph_output_names.end(), output->Name()) != graph_output_names.end()) { graph_outputs_to_add[output] = output_order; } fused_outputs[output] = output_order++; @@ -733,31 +785,156 
@@ static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { - static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", - "ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", "ATen", "AveragePool", - "BatchNormalization", "Cast", "Ceil", "Celu", "Clip", "Concat", "Constant", "ConstantFill", - "ConstantOfShape", "Conv", "ConvInteger", "ConvTranspose", "Cos", "Cosh", "CumSum", - "DepthToSpace", "DequantizeLinear", "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", - "Expand", "EyeLike", "Flatten", "Floor", "GRU", "Gather", "GatherElements", "GatherND", "Gemm", "GlobalAveragePool", - "GlobalMaxPool", "Greater", "GreaterOrEqual", "HardSigmoid", "HardSwish", "Identity", - "If", "ImageScaler", "InstanceNormalization", "IsNan", "LeakyRelu", "Less", "LessOrEqual", - "Log", "LogSoftmax", "Loop", "LpNormalization", "LRN", "LSTM", "MatMul", "MatMulInteger", "Max", "MaxPool", - "Mean", "Min", "Mod", "Mul", "Multinomial", "Neg", "NonMaxSuppression", "NonZero", "Not", - "OneHot", "Or", "Pad", "Pow", "PRelu", "QuantizeLinear", "RandomNormal", "RandomNormalLike", - "RandomUniform", "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", - "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", - "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", "Resize", "ReverseSequence", "RNN", "Roialign", "Round", - "Scatter", "ScatterElements", "ScatterND", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Softplus", - "Softsign", "SpaceToDepth", "Split", "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", - "ThresholdedRelu", "Tile", "TopK", "Transpose", "Trilu", "Unsqueeze", "Upsample", "Where", "Xor"}; + static std::set mgx_supported_ops = {"Abs", + "Acos", + "Acosh", + "Add", + "And", + "ArgMax", + "ArgMin", + "Asin", + "Asinh", + "Atan", + "Atanh", + "ATen", + "AveragePool", + "BatchNormalization", + 
"Cast", + "Ceil", + "Celu", + "Clip", + "Concat", + "Constant", + "ConstantFill", + "ConstantOfShape", + "Conv", + "ConvInteger", + "ConvTranspose", + "Cos", + "Cosh", + "CumSum", + "DepthToSpace", + "DequantizeLinear", + "Div", + "Dropout", + "Elu", + "Equal", + "Erf", + "Exp", + "Expand", + "EyeLike", + "Flatten", + "Floor", + "GRU", + "Gather", + "GatherElements", + "GatherND", + "Gemm", + "GlobalAveragePool", + "GlobalMaxPool", + "Greater", + "GreaterOrEqual", + "HardSigmoid", + "HardSwish", + "Identity", + "If", + "ImageScaler", + "InstanceNormalization", + "IsNan", + "LeakyRelu", + "Less", + "LessOrEqual", + "Log", + "LogSoftmax", + "Loop", + "LpNormalization", + "LRN", + "LSTM", + "MatMul", + "MatMulInteger", + "Max", + "MaxPool", + "Mean", + "Min", + "Mod", + "Mul", + "Multinomial", + "Neg", + "NonMaxSuppression", + "NonZero", + "Not", + "OneHot", + "Or", + "Pad", + "Pow", + "PRelu", + "QLinearAdd", + "QLinearConv", + "QLinearMatMul", + "QuantizeLinear", + "RandomNormal", + "RandomNormalLike", + "RandomUniform", + "RandomUniformLike", + "Range", + "Reciprocal", + "ReduceL1", + "ReduceL2", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + "ReduceSumSquare", + "Relu", + "Reshape", + "Resize", + "ReverseSequence", + "RNN", + "Roialign", + "Round", + "Scatter", + "ScatterElements", + "ScatterND", + "Selu", + "Shape", + "Sigmoid", + "Sign", + "Sin", + "Sinh", + "Slice", + "Softmax", + "Softplus", + "Softsign", + "SpaceToDepth", + "Split", + "Sqrt", + "Squeeze", + "Sub", + "Sum", + "Tan", + "Tanh", + "ThresholdedRelu", + "Tile", + "TopK", + "Transpose", + "Trilu", + "Unsqueeze", + "Upsample", + "Where", + "Xor"}; std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { // Collect inputs that are initializers - 
graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { + graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, + &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) { mgx_required_initializers.insert(node_arg.Name()); - } }, true); + } }, + true); } else { unsupported_nodes_idx.push_back(node_idx); } @@ -770,7 +947,8 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, // is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph). // This functions returns vector of all supported_subgraphx by amdmigraphx static std::vector> -GetPartitionedSubgraphs(const std::vector& topological_order, const std::vector& unsupported_nodes) { +GetPartitionedSubgraphs(const std::vector& topological_order, + const std::vector& unsupported_nodes) { std::vector> mgx_subgraphx; auto prev = topological_order.begin(); @@ -948,6 +1126,24 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_fp16(prog); } + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable_ && int8_calibration_cache_available_) { + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + + auto param_shapes = prog.get_parameter_shapes(); + + for (auto&& name : param_shapes.names()) { + auto dynamic_range_i = dynamic_range_map.find(name); + if (dynamic_range_i != dynamic_range_map.end()) { + quant_params.add(name, migraphx::argument(param_shapes[name], &(dynamic_range_i->second))); + } + } + + quant_opts.add_calibration_data(quant_params); + // perform static quantization on the programs + migraphx::quantize_int8(prog, t_, quant_opts); + } prog.compile(t_); auto prog_output_shapes = prog.get_output_shapes(); for (std::size_t i = 0; i < output_names.size(); ++i) { @@ 
-967,7 +1163,8 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, - map_no_input_shape_[context->node_name], fp16_enable_, dump_model_ops_}; + map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_, + int8_calibration_cache_available_, dynamic_range_map, dump_model_ops_}; *state = p.release(); return 0; }; @@ -982,12 +1179,15 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& MIGraphXFuncState* mgx_state = reinterpret_cast(state); std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; + std::unordered_map& map_dynamic_range = mgx_state->dynamic_range_map; migraphx::target t = mgx_state->t; migraphx::program& prog = mgx_state->prog; std::string& onnx_string = mgx_state->onnx_string; migraphx::onnx_options& cmp_options = mgx_state->options; bool& no_input_shape = mgx_state->no_input_shape; bool fp16_enable = mgx_state->fp16_enable; + bool int8_enable = mgx_state->int8_enable; + bool int8_calibration_cache_available = mgx_state->int8_calibration_cache_available; // mean no program at all, so need to get the input shape info // from input data @@ -1043,6 +1243,25 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::quantize_fp16(prog); } + // Read in the calibration data and map it to an migraphx paramater map for the calibration ops + if (int8_enable && int8_calibration_cache_available) { + migraphx::quantize_int8_options quant_opts; + migraphx::program_parameters quant_params; + + auto param_shapes = prog.get_parameter_shapes(); + + for (auto&& name : param_shapes.names()) { + auto dynamic_range_i = map_dynamic_range.find(name); + if (dynamic_range_i != map_dynamic_range.end()) { + quant_params.add(name, migraphx::argument(param_shapes[name], 
&(dynamic_range_i->second))); + } + } + + quant_opts.add_calibration_data(quant_params); + // perform static quantization on the programs + migraphx::quantize_int8(prog, t, quant_opts); + } + prog.compile(t); mgx_state->prog = prog; param_shapes = prog.get_parameter_shapes(); @@ -1137,9 +1356,11 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } -void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { +void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, + AllocatorMap& allocators) const { auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; - RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); + RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, + false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); } OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 1f591f9a1c0a5..c094be51012e4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,23 +3,29 @@ #pragma once +#include +#include + #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" -#include "migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_execution_provider_info.h" +#include "core/providers/migraphx/migraphx_inc.h" #include -#include "migraphx_inc.h" +#include // TODO: find a better way to share this // #include 
"core/providers/cuda/rocm_stream_handle.h" -#include -#include namespace onnxruntime { namespace migraphx_env_vars { -static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"; -static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; +static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE"; +static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE"; +static const char dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS"; +static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"; +static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH"; +static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"; }; // namespace migraphx_env_vars // Information to construct kernel function state. @@ -35,6 +41,9 @@ struct MIGraphXFuncState { OrtMutex* mgx_mu_ptr = nullptr; bool no_input_shape = false; bool fp16_enable = false; + bool int8_enable = false; + bool int8_calibration_cache_available = false; + std::unordered_map dynamic_range_map; bool dump_model_ops = false; }; @@ -69,6 +78,12 @@ class MIGraphXExecutionProvider : public IExecutionProvider { private: bool fp16_enable_ = false; + bool int8_enable_ = false; + std::string int8_calibration_cache_name_; + bool int8_calibration_cache_available_ = false; + bool int8_use_native_migraphx_calibration_table_ = false; + std::string calibration_cache_path_; + std::unordered_map dynamic_range_map; bool dump_model_ops_ = false; int device_id_; migraphx::target t_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index bdf8388e75c15..b7d7a77853df6 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -14,7 +14,10 @@ namespace migraphx { namespace provider_option_names { constexpr const char* kDeviceId = 
"device_id"; constexpr const char* kFp16Enable = "trt_fp16_enable"; -constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8Enable = "migx_int8_enable"; +constexpr const char* kInt8CalibTable = "migx_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "migx_int8_use_native_calibration_table"; + } // namespace provider_option_names } // namespace migraphx @@ -45,7 +48,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}}; + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, + }; return options; } @@ -53,7 +57,8 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap const ProviderOptions options{ {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, - {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}}; + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, + }; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 472d418c9099c..18ac30fdc1283 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/framework/ortdevice.h" #include 
"core/framework/provider_options.h" @@ -16,6 +17,8 @@ struct MIGraphXExecutionProviderInfo { int device_id{0}; bool fp16_enable{false}; bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index fb0be15986111..071070e92a209 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -2,8 +2,20 @@ // Licensed under the MIT License #pragma once + +#include +#include +#include +#include +#include +#include +#include "flatbuffers/idl.h" +#include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/execution_provider.h" +#include "core/common/path_string.h" + +namespace fs = std::filesystem; namespace onnxruntime { @@ -101,7 +113,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector return true; } -bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector indices, std::vector& input_nodes) { +bool canEvalNodeArgument(const GraphViewer& graph, + const Node* node, + std::vector indices, + std::vector& input_nodes) { input_nodes.clear(); std::vector in_nodes; for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) { @@ -137,4 +152,102 @@ bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector return true; } +float ConvertSinglePrecisionIEEE754ToFloat(uint32_t input) { + int s = (input >> 31) & 0x01; + int e = ((input & 0x7f800000) >> 23) - 127; + int p = -1; + double m = 0.0; + for (int i = 0; i < 23; 
++i) { + m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--); + } + return static_cast((s ? -1 : 1) * pow(2.0, e) * (m + 1.0)); +} + +/* + * Read calibration table for INT8 quantization + * Two kind of calibration tables are supported, + * 1. ORT generated calibration table + * The table is pre-serialized by flatbuffers. + * Each entry in the table is a key-value pair, + * key: tensor name, value: maximum absolute value in floating point + * For example, + * data_0 2.008338 + * ... + * 2. Native TensorRT generated calibration table + * Data format is defined by TensorRT as, + * tensor name : scale in 32-bit single precision IEEE754 format + * For example, + * TRT-7103-EntropyCalibration2 + * data_0: 4000889d + * ... + * + * Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models + * + */ +bool ReadDynamicRange(const std::string file_name, + const bool is_calibration_table, + std::unordered_map& dynamic_range_map) { + std::ifstream infile(file_name, std::ios::binary | std::ios::in); + if (!infile) { + return false; + } + + if (is_calibration_table) { + // Native TensorRT generated calibration table + std::string line; + char delim = ':'; + if (std::getline(infile, line)) { + std::istringstream first_line(line); + std::string version; + std::getline(first_line, version, delim); + std::size_t found = version.find("TRT-"); + if (found != std::string::npos) { + while (std::getline(infile, line)) { + std::istringstream in_line(line); + std::string str; + std::getline(in_line, str, delim); + std::string tensor_name = str; + std::getline(in_line, str, delim); + uint32_t scale_int = std::strtoul(str.c_str(), nullptr, 16); + float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int); + float dynamic_range = scale_float * 127.0f; + dynamic_range_map[tensor_name] = dynamic_range; + } + } else { + throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + } + } + } else { + // ORT generated 
calibration table + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read(reinterpret_cast(data.get()), length); + infile.close(); + auto flat_table = flatbuffers::GetRoot(reinterpret_cast(data.get())); + auto flat_dict = flat_table->dict(); + for (size_t i = 0, end = flat_dict->size(); i < end; ++i) { + flatbuffers::uoffset_t idx = static_cast(i); + dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str()); + } + } + return true; +} + +/* + * Get cache by name + * + */ +std::string GetCachePath(const std::string& root, const std::string& name) { + if (root.empty()) { + return name; + } else { + fs::path path = root; + path.append(name); + return path.string(); + } +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 8358ca5fcda95..f985682ddc735 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" @@ -8,7 +9,6 @@ #include "hip_allocator.h" #include "gpu_data_transfer.h" #include "core/framework/provider_options.h" -#include #include "core/session/onnxruntime_c_api.h" @@ -48,15 +48,37 @@ struct MIGraphX_Provider : Provider { info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; info.int8_enable = options.migraphx_int8_enable; + info.int8_calibration_table_name = ""; + if (options.migraphx_int8_calibration_table_name != nullptr) { + info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name; + } + info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0; return std::make_shared(info); } void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); - auto& trt_options = *reinterpret_cast(provider_options); - trt_options.device_id = internal_options.device_id; - trt_options.migraphx_fp16_enable = internal_options.fp16_enable; - trt_options.migraphx_int8_enable = internal_options.int8_enable; + auto& migx_options = *reinterpret_cast(provider_options); + migx_options.device_id = internal_options.device_id; + migx_options.migraphx_fp16_enable = internal_options.fp16_enable; + migx_options.migraphx_int8_enable = internal_options.int8_enable; + + char* dest = nullptr; + auto str_size = internal_options.int8_calibration_table_name.size(); + if (str_size == 0) { + migx_options.migraphx_int8_calibration_table_name = nullptr; + } else { + dest = new char[str_size + 1]; +#ifdef _MSC_VER + strncpy_s(dest, str_size + 1, internal_options.int8_calibration_table_name.c_str(), str_size); +#else + strncpy(dest, internal_options.int8_calibration_table_name.c_str(), str_size); +#endif + dest[str_size] = '\0'; + 
migx_options.migraphx_int8_calibration_table_name = (const char*)dest; + } + + migx_options.migraphx_use_native_calibration_table = internal_options.int8_use_native_calibration_table; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h new file mode 100644 index 0000000000000..9639040f772da --- /dev/null +++ b/onnxruntime/core/providers/migraphx/ort_trt_int8_cal_table.fbs.h @@ -0,0 +1,145 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ +#define ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ + +#include +#include "flatbuffers/flatbuffers.h" + +namespace CalTableFlatBuffers { + +struct KeyValue; +struct KeyValueBuilder; + +struct TrtTable; +struct TrtTableBuilder; + +struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KeyValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String* key() const { + return GetPointer(VT_KEY); + } + bool KeyCompareLessThan(const KeyValue* o) const { + return *key() < *o->key(); + } + int KeyCompareWithValue(const char* val) const { + return strcmp(key()->c_str(), val); + } + const flatbuffers::String* value() const { + return GetPointer(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KEY) && + verifier.VerifyString(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyString(value()) && + verifier.EndTable(); + } +}; + +struct KeyValueBuilder { + typedef KeyValue Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(KeyValue::VT_KEY, key); + 
} + void add_value(flatbuffers::Offset value) { + fbb_.AddOffset(KeyValue::VT_VALUE, value); + } + explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KeyValueBuilder& operator=(const KeyValueBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KeyValue::VT_KEY); + return o; + } +}; + +inline flatbuffers::Offset CreateKeyValue( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset key = 0, + flatbuffers::Offset value = 0) { + KeyValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKeyValueDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const char* key = nullptr, + const char* value = nullptr) { + auto key__ = key ? _fbb.CreateString(key) : 0; + auto value__ = value ? _fbb.CreateString(value) : 0; + return CalTableFlatBuffers::CreateKeyValue( + _fbb, + key__, + value__); +} + +struct TrtTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TrtTableBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DICT = 4 + }; + const flatbuffers::Vector>* dict() const { + return GetPointer>*>(VT_DICT); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DICT) && + verifier.VerifyVector(dict()) && + verifier.VerifyVectorOfTables(dict()) && + verifier.EndTable(); + } +}; + +struct TrtTableBuilder { + typedef TrtTable Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_dict(flatbuffers::Offset>> dict) { + fbb_.AddOffset(TrtTable::VT_DICT, dict); + } + explicit TrtTableBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TrtTableBuilder& operator=(const TrtTableBuilder&); + flatbuffers::Offset Finish() { + const auto 
end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTrtTable( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset>> dict = 0) { + TrtTableBuilder builder_(_fbb); + builder_.add_dict(dict); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTrtTableDirect( + flatbuffers::FlatBufferBuilder& _fbb, + std::vector>* dict = nullptr) { + auto dict__ = dict ? _fbb.CreateVectorOfSortedTables(dict) : 0; + return CalTableFlatBuffers::CreateTrtTable( + _fbb, + dict__); +} + +} // namespace CalTableFlatBuffers + +#endif // ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORT_TRT_INT8_CAL_TABLE_FBS_H_ diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 2027b592326df..56312898b0d16 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -730,33 +730,115 @@ std::unique_ptr CreateExecutionProviderInstance( } } } - LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please reference https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met."; + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". 
Please reference " + << "https://onnxruntime.ai/docs/execution-providers/" + << "TensorRT-ExecutionProvider.html#requirements to ensure all dependencies are met."; #endif } else if (type == kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX - return onnxruntime::MIGraphXProviderFactoryCreator::Create(0)->CreateProvider(); + std::string calibration_table; + auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + OrtMIGraphXProviderOptions params{ + 0, + 0, + 0, + 0, + nullptr}; + for (auto option : it->second) { + if (option.first == "device_id") { + if (!option.second.empty()) { + params.device_id = std::stoi(option.second); + } else { + ORT_THROW("[ERROR] [MIGraphX] The value for the key 'device_id' should be a number i.e. '0'.\n"); + } + } else if (option.first == "migraphx_fp16_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_fp16_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_fp16_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_fp16_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_int8_enable") { + if (option.second == "True" || option.second == "true") { + params.migraphx_int8_enable = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_int8_enable = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_enable' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else if (option.first == "migraphx_int8_calibration_table_name") { + if (!option.second.empty()) { + calibration_table = option.second; + params.migraphx_int8_calibration_table_name = calibration_table.c_str(); + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_int8_calibration_table_name' should be a " + "file name i.e. 
'cal_table'.\n"); + } + } else if (option.first == "migraphx_use_native_calibration_table") { + if (option.second == "True" || option.second == "true") { + params.migraphx_use_native_calibration_table = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_use_native_calibration_table = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_use_native_calibration_table' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } + } else { + ORT_THROW("Invalid MIGraphX EP option: ", option.first); + } + } + if (std::shared_ptr migraphx_provider_factory = + onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) { + return migraphx_provider_factory->CreateProvider(); + } + } else { + if (std::shared_ptr migraphx_provider_factory = + onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) { + return migraphx_provider_factory->CreateProvider(); + } + } #endif } else if (type == kCudaExecutionProvider) { #ifdef USE_CUDA - // If the environment variable 'CUDA_UNAVAILABLE' exists, then we do not load cuda. This is set by _ld_preload for the manylinux case - // as in that case, trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. + // If the environment variable 'CUDA_UNAVAILABLE' exists, then we do not load cuda. + // This is set by _ld_preload for the manylinux case as in that case, + // trying to load the library itself will result in a crash due to the way that auditwheel strips dependencies. if (Env::Default().GetEnvironmentVar("ORT_CUDA_UNAVAILABLE").empty()) { if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) { const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info, provider_options_map); - // This variable is never initialized because the APIs by which it should be initialized are deprecated, however they still - // exist are are in-use. 
Neverthless, it is used to return CUDAAllocator, hence we must try to initialize it here if we can - since FromProviderOptions might contain external CUDA allocator. + // This variable is never initialized because the APIs by which it should be initialized are deprecated, + // however they still exist and are in use. Nevertheless, it is used to return CUDAAllocator, + // hence we must try to initialize it here if we can since FromProviderOptions might contain + // external CUDA allocator. external_allocator_info = info.external_allocator_info; return cuda_provider_info->CreateExecutionProviderFactory(info)->CreateProvider(); } else { if (!Env::Default().GetEnvironmentVar("CUDA_PATH").empty()) { - ORT_THROW("CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and cuDNN as mentioned in the GPU requirements page (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), make sure they're in the PATH, and that your GPU is supported."); + ORT_THROW( + "CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and " + "cuDNN as mentioned in the GPU requirements page " + " (https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements), " + " make sure they're in the PATH, and that your GPU is supported."); + } + } + } - LOGS_DEFAULT(WARNING) << "Failed to create " << type << ". Please reference https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met."; + LOGS_DEFAULT(WARNING) << "Failed to create " + << type + << ". 
Please reference " + << "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements" + << " to ensure all dependencies are met."; #endif } else if (type == kRocmExecutionProvider) { #ifdef USE_ROCM diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 97330295e17ed..f506516442b1e 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -779,7 +779,7 @@ def main(): logger.error("fp16 is for GPU only") return - if args.precision == Precision.INT8 and args.use_gpu: + if args.precision == Precision.INT8 and args.use_gpu and args.provider != "migraphx": logger.error("int8 is for CPU only") return diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index e224507bc740e..65646a7286719 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -69,7 +69,9 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { OrtMIGraphXProviderOptions params{ 0, 0, - 0}; + 0, + 0, + nullptr}; return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider(); #else return nullptr;