From c392fdeb1b442aa213b664edc55e079efb336540 Mon Sep 17 00:00:00 2001
From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Date: Wed, 2 Aug 2023 10:15:34 -0700
Subject: [PATCH] RunAsync Python API (#16760)

Implement python binding for RunAsync API.

---------

Co-authored-by: Randy Shuai
---
 .../onnxruntime_inference_collection.py      |  30 +++++
 .../python/onnxruntime_pybind_state.cc       | 126 ++++++++++++++++++
 .../test/python/onnxruntime_test_python.py   |  31 +++++
 3 files changed, 187 insertions(+)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 408e533da3f1f..b73fcbbff5456 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -228,6 +228,36 @@ def run(self, output_names, input_feed, run_options=None):
                 return self._sess.run(output_names, input_feed, run_options)
             raise
 
+    def run_async(self, output_names, input_feed, callback, user_data, run_options=None):
+        """
+        Compute the predictions asynchronously on a separate C++ thread from the ORT intra-op thread pool.
+
+        :param output_names: name of the outputs
+        :param input_feed: dictionary ``{ input_name: input_value }``
+        :param callback: python function that accepts an array of results and a status string on error.
+            The callback will be invoked by a C++ thread from the ORT intra-op thread pool.
+        :param user_data: arbitrary python object passed back to the callback untouched.
+        :param run_options: See :class:`onnxruntime.RunOptions`.
+
+        ::
+
+            class MyData:
+                def __init__(self):
+                    # ...
+                def save_results(self, results):
+                    # ...
+
+            def callback(results: np.ndarray, user_data: MyData, err: str) -> None:
+                if err:
+                    print(err)
+                else:
+                    # save results to user_data
+                    user_data.save_results(results)
+
+            my_data = MyData()
+            sess.run_async([output_name], {input_name: x}, callback, my_data)
+        """
+        self._validate_input(list(input_feed.keys()))
+        if not output_names:
+            output_names = [output.name for output in self._outputs_meta]
+        return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)
+
     def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
         """
         Compute the predictions.
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index f81a143151fde..826c996c22d6e 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -36,6 +36,8 @@
 #include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
 #endif
 
+#include <pybind11/functional.h>
+
 // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
 // GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
 // GCC 4.x.
@@ -74,6 +76,83 @@ static Env& platform_env = Env::Default();
 #pragma warning(push)
 #endif
 
+using PyCallback = std::function<void(std::vector<py::object>, py::object user_data, std::string)>;
+
+struct AsyncResource {
+  std::vector<OrtValue> feeds;
+  std::vector<const OrtValue*> feeds_raw;
+
+  std::vector<std::string> feed_names;
+  std::vector<const char*> feed_names_raw;
+
+  std::vector<OrtValue*> fetches_raw;
+
+  std::vector<std::string> fetch_names;
+  std::vector<const char*> fetch_names_raw;
+
+  RunOptions default_run_option;
+  PyCallback callback;
+  py::object user_data;
+
+  void ReserveFeeds(size_t sz) {
+    feeds.reserve(sz);
+    feeds_raw.reserve(sz);
+    feed_names.reserve(sz);
+    feed_names_raw.reserve(sz);
+  }
+
+  void ReserveFetches(size_t sz) {
+    fetches_raw.reserve(sz);
+    fetch_names.reserve(sz);
+    fetch_names_raw.reserve(sz);
+  }
+};
+
+void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
+  ORT_ENFORCE(user_data, "user data must not be NULL for callback in python");
+
+  auto invoke_callback = [&]() {
+    std::unique_ptr<AsyncResource> async_resource{reinterpret_cast<AsyncResource*>(user_data)};
+    Ort::Status status(ort_status);
+
+    // return on error
+    if (!status.IsOK()) {
+      async_resource->callback({}, async_resource->user_data, status.GetErrorMessage());
+      return;
+    }
+
+    std::vector<py::object> rfetch;
+    rfetch.reserve(num_outputs);
+    size_t pos = 0;
+    for (size_t ith = 0; ith < num_outputs; ++ith) {
+      const auto& fet = *outputs[ith];
+      if (fet.IsAllocated()) {
+        if (fet.IsTensor()) {
+          rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
+        } else if (fet.IsSparseTensor()) {
+          rfetch.push_back(GetPyObjectFromSparseTensor(pos, fet, nullptr));
+        } else {
+          rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
+        }
+      } else {
+        rfetch.push_back(py::none());
+      }
+      ++pos;
+    }
+    async_resource->callback(rfetch, async_resource->user_data, "");
+  };
+
+  if (PyGILState_Check()) {
+    invoke_callback();
+  } else {
+    // acquire GIL to safely:
+    // 1) invoke python callback
+    // 2) create, manipulate, and destroy python objects
+    py::gil_scoped_acquire acquire;
+    invoke_callback();
+  }
+}
+
 template <typename T>
 static py::object AddNonTensor(const OrtValue& val,
                                const DataTransferManager* /*data_transfer_manager*/,
@@ -1680,6 +1759,53 @@ including arg name, arg type (contains both type and shape).)pbdoc")
           }
           return rfetch;
         })
+      .def("run_async",
+           [](PyInferenceSession* sess,
+              std::vector<std::string> output_names,
+              std::map<std::string, py::object> pyfeeds,
+              PyCallback callback, py::object user_data = {},
+              RunOptions* run_options = nullptr)
+               -> void {
+             std::unique_ptr<AsyncResource> async_resource = std::make_unique<AsyncResource>();
+             async_resource->callback = callback;
+             async_resource->user_data = user_data;
+             // prepare feeds
+             async_resource->ReserveFeeds(pyfeeds.size());
+             for (auto feed : pyfeeds) {
+               if (!feed.second.is(py::none())) {
+                 OrtValue ml_value;
+                 auto px = sess->GetSessionHandle()->GetModelInputs();
+                 if (!px.first.IsOK() || !px.second) {
+                   throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
+                 }
+                 CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
+                 ThrowIfPyErrOccured();
+                 async_resource->feeds.push_back(ml_value);
+                 async_resource->feeds_raw.push_back(&async_resource->feeds.back());
+                 async_resource->feed_names.push_back(feed.first);
+                 async_resource->feed_names_raw.push_back(async_resource->feed_names.back().c_str());
+               }
+             }
+             // prepare fetches
+             async_resource->ReserveFetches(output_names.size());
+             for (auto& output_name : output_names) {
+               async_resource->fetch_names.push_back(output_name);
+               async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
+               async_resource->fetches_raw.push_back({});
+             }
+             const RunOptions* run_async_option = run_options ? run_options : &async_resource->default_run_option;
+             common::Status status = sess->GetSessionHandle()->RunAsync(run_async_option,
+                                                                        gsl::span(async_resource->feed_names_raw.data(), async_resource->feed_names_raw.size()),
+                                                                        gsl::span(async_resource->feeds_raw.data(), async_resource->feeds_raw.size()),
+                                                                        gsl::span(async_resource->fetch_names_raw.data(), async_resource->fetch_names_raw.size()),
+                                                                        gsl::span(async_resource->fetches_raw.data(), async_resource->fetches_raw.size()),
+                                                                        AsyncCallback,
+                                                                        async_resource.get());
+             if (status.IsOK()) {
+               async_resource.release();
+             }
+             OrtPybindThrowIfError(status);
+           })
         /// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
         /// and returns a list of python objects representing OrtValues. Each name may represent either
         /// a Tensor, SparseTensor or a TensorSequence.
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index e4522348c6e9e..e554d418667a1 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -586,6 +586,37 @@ def test_run_model(self):
         output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
         np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
 
+    def test_run_async(self):
+        event = threading.Event()
+        output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
+
+        class MyData:
+            def __init__(self, id):
+                self.__id = id
+
+            def get_id(self):
+                return self.__id
+
+        my_data = MyData(123456)
+
+        def callback(res: np.ndarray, data: MyData, err: str) -> None:
+            self.assertEqual(len(err), 0)
+            self.assertEqual(len(res), 1)
+            self.assertEqual(data.get_id(), 123456)
+            np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
+            event.set()
+
+        so = onnxrt.SessionOptions()
+        so.intra_op_num_threads = 2
+
+        sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), so, providers=available_providers)
+
+        x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
+        sess.run_async(["Y"], {"X": x}, callback, my_data)
+
+        event.wait(10)  # timeout in 10 sec
+        self.assertTrue(event.is_set())
+
     def test_run_model_from_bytes(self):
         with open(get_name("mul_1.onnx"), "rb") as f:
             content = f.read()
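
For readers trying out the new API, here is a minimal end-to-end sketch of calling run_async and blocking until the callback fires, mirroring the test above. The model file "model.onnx" and the tensor names "X" and "Y" are placeholders, not part of this patch; any loadable model with matching input/output names would do.

    import threading

    import numpy as np
    import onnxruntime as onnxrt

    # completion event lets the caller block until the callback has fired
    event = threading.Event()

    def callback(results, user_data, err):
        # err is a non-empty message on failure, "" on success;
        # results is a list of numpy arrays, one per requested output
        if err:
            print(err)
        else:
            user_data["outputs"] = results
        event.set()

    state = {}  # user_data can be any python object
    sess = onnxrt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])  # placeholder model
    x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)

    # returns immediately; the callback runs later on an ORT intra-op thread
    sess.run_async(["Y"], {"X": x}, callback, state)

    event.wait(10)  # timeout in 10 sec
    print(state.get("outputs"))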
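
Because the callback fires on a C++ thread from the intra-op pool rather than on the calling thread, asyncio applications need to marshal the result back onto the event loop. The bridge below is an illustrative pattern layered on top of the new API, not something this patch provides; run_model_async is a hypothetical helper name, and "model.onnx", "X", "Y" are placeholders as above.

    import asyncio

    import numpy as np
    import onnxruntime as onnxrt

    async def run_model_async(sess, output_names, input_feed):
        loop = asyncio.get_running_loop()
        future = loop.create_future()

        def callback(results, user_data, err):
            # hop back onto the event-loop thread before touching the future
            if err:
                loop.call_soon_threadsafe(future.set_exception, RuntimeError(err))
            else:
                loop.call_soon_threadsafe(future.set_result, results)

        sess.run_async(output_names, input_feed, callback, None)
        return await future

    async def main():
        sess = onnxrt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])  # placeholder model
        x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
        results = await run_model_async(sess, ["Y"], {"X": x})
        print(results[0])

    asyncio.run(main())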