Add ability to set CoreML EP flags from Python #21434

117 changes: 57 additions & 60 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -35,6 +35,10 @@
 #include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
 #endif
 
+#if defined(USE_COREML)
+#include "core/providers/coreml/coreml_provider_factory.h"
+#endif
+
 #include <pybind11/functional.h>
 
 // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
@@ -1156,7 +1160,30 @@
 #if !defined(__APPLE__)
     LOGS_DEFAULT(WARNING) << "CoreML execution provider can only be used to generate ORT format model in this build.";
 #endif
-    return onnxruntime::CoreMLProviderFactoryCreator::Create(0)->CreateProvider();
+    uint32_t coreml_flags = 0;
+
+    const auto it = provider_options_map.find(type);
+    if (it != provider_options_map.end()) {
+      const ProviderOptions& options = it->second;
+      auto flags = options.find("flags");
+      if (flags != options.end()) {
+        const auto& flags_str = flags->second;
+
+        if (flags_str.find("COREML_FLAG_USE_CPU_ONLY") != std::string::npos) {
+          coreml_flags |= COREMLFlags::COREML_FLAG_USE_CPU_ONLY;
+        }
+
+        if (flags_str.find("COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES") != std::string::npos) {
+          coreml_flags |= COREMLFlags::COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES;
+        }
+
+        if (flags_str.find("COREML_FLAG_CREATE_MLPROGRAM") != std::string::npos) {
+          coreml_flags |= COREMLFlags::COREML_FLAG_CREATE_MLPROGRAM;
+        }
+      }
+    }
+
+    return onnxruntime::CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider();
 #endif
   } else if (type == kXnnpackExecutionProvider) {
 #if defined(USE_XNNPACK)
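
With this change the CoreML flags become reachable through the standard `provider_options` mechanism of the Python API. Because the C++ side matches flag names with `std::string::find`, several flags can be combined in a single `flags` string. A minimal sketch, assuming a build with the CoreML EP enabled (the model path is a placeholder):

```python
import onnxruntime as ort

# "model.onnx" is a placeholder; any ONNX model applies.
# Flag names are matched by substring, so one string can carry several flags.
sess = ort.InferenceSession(
    "model.onnx",
    providers=["CoreMLExecutionProvider"],
    provider_options=[{"flags": "COREML_FLAG_CREATE_MLPROGRAM COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES"}],
)
```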
@@ -1419,7 +1446,7 @@
         ORT_UNUSED_PARAMETER(algo);
         ORT_THROW("set_cudnn_conv_algo_search is not supported in ROCM");
 #else
-          cudnn_conv_algo_search = algo;
+        cudnn_conv_algo_search = algo;
 #endif
       });
   // TODO remove deprecated global config
@@ -1430,7 +1457,7 @@
         ORT_UNUSED_PARAMETER(use_single_stream);
         ORT_THROW("set_do_copy_in_default_stream is not supported in ROCM");
 #else
-          do_copy_in_default_stream = use_single_stream;
+        do_copy_in_default_stream = use_single_stream;
 #endif
       });
   // TODO remove deprecated global config
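
The two setters above are whitespace-only changes to deprecated global config, as the TODO notes. The non-deprecated route is to pass the equivalent options to the CUDA EP per session; a sketch, assuming a CUDA-enabled build and a placeholder model path:

```python
import onnxruntime as ort

# Per-session CUDA EP provider options replace the deprecated global setters.
sess = ort.InferenceSession(
    "model.onnx",  # placeholder
    providers=["CUDAExecutionProvider"],
    provider_options=[{
        "cudnn_conv_algo_search": "EXHAUSTIVE",  # or "HEURISTIC" / "DEFAULT"
        "do_copy_in_default_stream": "1",
    }],
)
```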
@@ -1795,10 +1822,10 @@
         }
         ORT_THROW_IF_ERROR(options->value.AddExternalInitializers(names_ptrs, values_ptrs));
 #else
-          ORT_UNUSED_PARAMETER(options);
-          ORT_UNUSED_PARAMETER(names);
-          ORT_UNUSED_PARAMETER(ort_values);
-          ORT_THROW("External initializers are not supported in this build.");
+        ORT_UNUSED_PARAMETER(options);
+        ORT_UNUSED_PARAMETER(names);
+        ORT_UNUSED_PARAMETER(ort_values);
+        ORT_THROW("External initializers are not supported in this build.");
 #endif
       });
 
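On the Python side this binding surfaces as `SessionOptions.add_external_initializers`, taking parallel lists of names and `OrtValue`s. A rough sketch (the initializer name and model path are illustrative):

```python
import numpy as np
import onnxruntime as ort

# Provide initializer data from memory instead of an external data file.
weights = ort.OrtValue.ortvalue_from_numpy(np.zeros((4, 4), dtype=np.float32))
so = ort.SessionOptions()
so.add_external_initializers(["fc1.weight"], [weights])  # name is illustrative
sess = ort.InferenceSession("model.onnx", sess_options=so)  # placeholder path
```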
@@ -1860,8 +1887,7 @@
         return *(na.Type());
       },
       "node type")
-      .def(
-          "__str__", [](const onnxruntime::NodeArg& na) -> std::string {
+      .def("__str__", [](const onnxruntime::NodeArg& na) -> std::string {
         std::ostringstream res;
         res << "NodeArg(name='" << na.Name() << "', type='" << *(na.Type()) << "', shape=";
         auto shape = na.Shape();
@@ -1887,11 +1913,8 @@
         }
         res << ")";
 
-        return std::string(res.str());
-      },
-      "converts the node into a readable string")
-      .def_property_readonly(
-          "shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
+        return std::string(res.str()); }, "converts the node into a readable string")
+      .def_property_readonly("shape", [](const onnxruntime::NodeArg& na) -> std::vector<py::object> {
         auto shape = na.Shape();
         std::vector<py::object> arr;
         if (shape == nullptr || shape->dim_size() == 0) {
@@ -1908,9 +1931,7 @@
             arr[i] = py::none();
           }
         }
-        return arr;
-      },
-      "node shape (assuming the node holds a tensor)");
+        return arr; }, "node shape (assuming the node holds a tensor)");
 
   py::class_<SessionObjectInitializer> sessionObjectInitializer(m, "SessionObjectInitializer");
   py::class_<PyInferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
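
These `NodeArg` bindings back the objects returned by `InferenceSession.get_inputs()` and `get_outputs()`; for example (placeholder model path):

```python
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")  # placeholder path
for na in sess.get_inputs():
    # __str__ renders e.g. NodeArg(name='x', type='tensor(float)', shape=['batch', 3, 224, 224])
    print(na)
    # Symbolic dimensions come back as strings, unknown ones as None.
    print(na.name, na.shape, na.type)
```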
@@ -2101,51 +2122,28 @@
       .def_property_readonly("get_profiling_start_time_ns", [](const PyInferenceSession* sess) -> uint64_t {
         return sess->GetSessionHandle()->GetProfiling().GetStartTimeNs();
       })
-      .def(
-          "get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& {
-            return sess->GetSessionHandle()->GetRegisteredProviderTypes();
-          },
-          py::return_value_policy::reference_internal)
-      .def(
-          "get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& {
-            return sess->GetSessionHandle()->GetAllProviderOptions();
-          },
-          py::return_value_policy::reference_internal)
-      .def_property_readonly(
-          "session_options", [](const PyInferenceSession* sess) -> PySessionOptions* {
+      .def("get_providers", [](const PyInferenceSession* sess) -> const std::vector<std::string>& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal)
+      .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal)
+      .def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* {
         auto session_options = std::make_unique<PySessionOptions>();
         session_options->value = sess->GetSessionHandle()->GetSessionOptions();
-        return session_options.release();
-      },
-      py::return_value_policy::take_ownership)
-      .def_property_readonly(
-          "inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        return session_options.release(); }, py::return_value_policy::take_ownership)
+      .def_property_readonly("inputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
         auto res = sess->GetSessionHandle()->GetModelInputs();
         OrtPybindThrowIfError(res.first);
-        return *(res.second);
-      },
-      py::return_value_policy::reference_internal)
-      .def_property_readonly(
-          "outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        return *(res.second); }, py::return_value_policy::reference_internal)
+      .def_property_readonly("outputs_meta", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
         auto res = sess->GetSessionHandle()->GetModelOutputs();
         OrtPybindThrowIfError(res.first);
-        return *(res.second);
-      },
-      py::return_value_policy::reference_internal)
-      .def_property_readonly(
-          "overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
+        return *(res.second); }, py::return_value_policy::reference_internal)
+      .def_property_readonly("overridable_initializers", [](const PyInferenceSession* sess) -> const std::vector<const onnxruntime::NodeArg*>& {
         auto res = sess->GetSessionHandle()->GetOverridableInitializers();
         OrtPybindThrowIfError(res.first);
-        return *(res.second);
-      },
-      py::return_value_policy::reference_internal)
-      .def_property_readonly(
-          "model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
+        return *(res.second); }, py::return_value_policy::reference_internal)
+      .def_property_readonly("model_meta", [](const PyInferenceSession* sess) -> const onnxruntime::ModelMetadata& {
         auto res = sess->GetSessionHandle()->GetModelMetadata();
         OrtPybindThrowIfError(res.first);
-        return *(res.second);
-      },
-      py::return_value_policy::reference_internal)
+        return *(res.second); }, py::return_value_policy::reference_internal)
       .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
         Status status;
         // release GIL to allow multiple python threads to invoke Run() in parallel.
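
The getters reformatted above are what the Python-visible `get_providers()`, `get_provider_options()`, and metadata properties dispatch to; combined with the new CoreML option, a round-trip check might look like:

```python
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["CoreMLExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"flags": "COREML_FLAG_USE_CPU_ONLY"}, {}],
)
print(sess.get_providers())         # e.g. ['CoreMLExecutionProvider', 'CPUExecutionProvider']
print(sess.get_provider_options())  # per-provider option dicts, including the flags string
meta = sess.get_modelmeta()         # served by the model_meta property
print(meta.graph_name, [i.name for i in sess.get_inputs()])
```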
@@ -2155,8 +2153,7 @@
         else
           status = sess->GetSessionHandle()->Run(*run_options, *io_binding.Get());
         if (!status.IsOK())
-          throw std::runtime_error("Error in execution: " + status.ErrorMessage());
-      })
+          throw std::runtime_error("Error in execution: " + status.ErrorMessage()); })
       .def("get_tuning_results", [](PyInferenceSession* sess) -> py::list {
 #if !defined(ORT_MINIMAL_BUILD)
         auto results = sess->GetSessionHandle()->GetTuningResults();
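
`run_with_iobinding` releases the GIL around `Run()`, so several Python threads can execute in parallel, and surfaces failures as `RuntimeError`. A minimal IOBinding sketch (tensor names and model path are placeholders):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")  # placeholder path
binding = sess.io_binding()
x = np.ones((1, 3), dtype=np.float32)      # placeholder input
binding.bind_cpu_input("x", x)             # "x"/"y" are placeholder tensor names
binding.bind_output("y")
sess.run_with_iobinding(binding)           # raises RuntimeError on failure
y = binding.copy_outputs_to_cpu()[0]
```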
@@ -2171,8 +2168,8 @@
 
         return ret;
 #else
-          ORT_UNUSED_PARAMETER(sess);
-          ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
+        ORT_UNUSED_PARAMETER(sess);
+        ORT_THROW("TunableOp and get_tuning_results are not supported in this build.");
 #endif
       })
       .def("set_tuning_results", [](PyInferenceSession* sess, py::list results, bool error_on_invalid) -> void {
@@ -2203,10 +2200,10 @@
           throw std::runtime_error("Error in execution: " + status.ErrorMessage());
         }
 #else
-          ORT_UNUSED_PARAMETER(sess);
-          ORT_UNUSED_PARAMETER(results);
-          ORT_UNUSED_PARAMETER(error_on_invalid);
-          ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
+        ORT_UNUSED_PARAMETER(sess);
+        ORT_UNUSED_PARAMETER(results);
+        ORT_UNUSED_PARAMETER(error_on_invalid);
+        ORT_THROW("TunableOp and set_tuning_results are not supported in this build.");
 #endif
       });
 
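The tuning-result bindings round-trip TunableOp state and throw in builds without TunableOp support. A hedged sketch of saving results from one session and restoring them in another (assuming a ROCm or CUDA build; the model path is a placeholder):

```python
import onnxruntime as ort

providers = ["ROCMExecutionProvider"]  # assumes a TunableOp-capable build
sess = ort.InferenceSession("model.onnx", providers=providers)
results = sess.get_tuning_results()  # list of per-EP result dicts

# ... persist `results` (e.g. as JSON), then later in a fresh session:
sess2 = ort.InferenceSession("model.onnx", providers=providers)
sess2.set_tuning_results(results, False)  # second arg: error_on_invalid
```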