Skip to content

Commit

Permalink
[MIGraphX EP] Set External Data Path (#21598)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Changes to add in Set external data path for model weight files.
Additional fixes to ensure this compiles off the latest v1.19
Onnxruntime


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Separate weights used for larger models (like stable diffusion) is
motivation for this change set

---------

Co-authored-by: Jeff Daily <[email protected]>
Co-authored-by: Artur Wojcik <[email protected]>
Co-authored-by: Ted Themistokleous <[email protected]>
  • Loading branch information
4 people authored Aug 2, 2024
1 parent 54d6614 commit 45b7c41
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <iterator>
#include <unordered_map>
#include <set>
#include <filesystem>

#include "core/providers/shared_library/provider_api.h"
#define ORT_API_MANUAL_INIT
Expand Down Expand Up @@ -990,6 +991,7 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
std::string onnx_string_buffer;
model_proto->SerializeToString(onnx_string_buffer);
model_path_ = graph_viewer.ModelPath();

// dump onnx file if environment var is set
if (dump_model_ops_) {
Expand Down Expand Up @@ -1168,7 +1170,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
auto param_shapes = prog.get_parameter_shapes();

// Add all calibration data read in from int8 table
for (auto& [cal_key, cal_val] : dynamic_range_map) {
for (auto& [cal_key, cal_val] : dynamic_range_map_) {
auto cal_val_shape = migraphx::shape(migraphx_shape_float_type);
quant_params.add(cal_key.c_str(), migraphx::argument(cal_val_shape, static_cast<void*>(std::move(&cal_val))));
}
Expand Down Expand Up @@ -1217,7 +1219,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
*p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
map_no_input_shape_[context->node_name], fp16_enable_, int8_enable_,
int8_calibration_cache_available_, dynamic_range_map,
int8_calibration_cache_available_, dynamic_range_map_,
save_compiled_model_, save_compiled_path_,
load_compiled_model_, load_compiled_path_, dump_model_ops_};
*state = p.release();
Expand Down Expand Up @@ -1297,6 +1299,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
if (!input_shape_match) {
if (!load_precompiled_model(prog, load_compiled_model_, std::string{load_compiled_path_})) {
LOGS_DEFAULT(VERBOSE) << "No Input shapes mismatch detected. Recompiling" << std::endl;
cmp_options.set_external_data_path(model_path_.has_parent_path() ? model_path_.parent_path().string() : std::filesystem::current_path().string());
prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);

// Read in the calibration data and map it to an migraphx paramater map for the calibration ops
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <map>
#include <unordered_map>
#include <filesystem>

namespace onnxruntime {

Expand Down Expand Up @@ -91,7 +92,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
bool int8_calibration_cache_available_ = false;
bool int8_use_native_migraphx_calibration_table_ = false;
std::string calibration_cache_path_;
std::unordered_map<std::string, float> dynamic_range_map;
std::unordered_map<std::string, float> dynamic_range_map_;
bool save_compiled_model_ = false;
std::string save_compiled_path_;
bool load_compiled_model_ = false;
Expand All @@ -100,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
migraphx::target t_;
OrtMutex mgx_mu_;
hipStream_t stream_ = nullptr;
mutable std::filesystem::path model_path_;

std::unordered_map<std::string, migraphx::program> map_progs_;
std::unordered_map<std::string, std::string> map_onnx_string_;
Expand Down

0 comments on commit 45b7c41

Please sign in to comment.