Skip to content

Commit

Permalink
RunAsync Python API (#16760)
Browse files Browse the repository at this point in the history
Implement python binding for RunAsync API.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Aug 2, 2023
1 parent bd4d011 commit c392fde
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 0 deletions.
30 changes: 30 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,36 @@ def run(self, output_names, input_feed, run_options=None):
return self._sess.run(output_names, input_feed, run_options)
raise

def run_async(self, output_names, input_feed, callback, user_data=None, run_options=None):
    """
    Compute the predictions asynchronously in a separate cxx thread from ort intra-op threadpool.

    :param output_names: name of the outputs; when empty or None, all model outputs are fetched
    :param input_feed: dictionary ``{ input_name: input_value }``
    :param callback: python function that accepts an array of results, the user data object,
        and a status string on error. The callback will be invoked by a cxx thread from
        ort intra-op threadpool, so it must be thread-safe and should not raise.
    :param user_data: arbitrary python object passed through unchanged to the callback
        (defaults to None).
    :param run_options: See :class:`onnxruntime.RunOptions`.
    ::

        class MyData:
            def __init__(self):
                # ...
            def save_results(self, results):
                # ...

        def callback(results: np.ndarray, user_data: MyData, err: str) -> None:
            if err:
                print(err)
            else:
                # save results to user_data

        sess.run_async([output_name], {input_name: x}, callback, MyData())
    """
    # Mirror the synchronous run(): validate feed names, then default to
    # every model output when the caller did not name any.
    self._validate_input(list(input_feed.keys()))
    if not output_names:
        output_names = [output.name for output in self._outputs_meta]
    return self._sess.run_async(output_names, input_feed, callback, user_data, run_options)

def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=None):
"""
Compute the predictions.
Expand Down
126 changes: 126 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
#endif

#include <pybind11/functional.h>

// Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct,
// GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses
// GCC 4.x.
Expand Down Expand Up @@ -74,6 +76,83 @@ static Env& platform_env = Env::Default();
#pragma warning(push)
#endif

// Signature of the python-side completion callback: (results, user_data, error_message).
using PyCallback = std::function<void(std::vector<py::object>, py::object user_data, std::string)>;

// Holds every buffer an in-flight RunAsync call references, keeping them alive
// until the completion callback fires; ownership is handed to AsyncCallback
// through the opaque user-data pointer.
struct AsyncResource {
  std::vector<OrtValue> feeds;             // owned input values
  std::vector<const OrtValue*> feeds_raw;  // raw views handed to RunAsync

  std::vector<std::string> feed_names;      // owned input names
  std::vector<const char*> feed_names_raw;  // c-string views into feed_names

  std::vector<OrtValue*> fetches_raw;  // output slots RunAsync fills in

  std::vector<std::string> fetch_names;      // owned output names
  std::vector<const char*> fetch_names_raw;  // c-string views into fetch_names

  RunOptions default_run_option;  // used when the caller supplies no RunOptions
  PyCallback callback;            // python completion callback
  py::object user_data;           // opaque object forwarded to the callback

  // Pre-size every feed-related buffer so later push_back calls never
  // reallocate — reallocation would invalidate the raw pointers/views above.
  void ReserveFeeds(size_t count) {
    feed_names.reserve(count);
    feed_names_raw.reserve(count);
    feeds.reserve(count);
    feeds_raw.reserve(count);
  }

  // Pre-size every fetch-related buffer, for the same pointer-stability reason.
  void ReserveFetches(size_t count) {
    fetch_names.reserve(count);
    fetch_names_raw.reserve(count);
    fetches_raw.reserve(count);
  }
};

// Completion hook registered with InferenceSession::RunAsync. Reclaims
// ownership of the AsyncResource allocated at submission time (carried in
// user_data), converts the fetched OrtValues to python objects, and invokes
// the python callback — acquiring the GIL first if this thread lacks it.
void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
  ORT_ENFORCE(user_data, "user data must not be NULL for callback in python");

  auto invoke_callback = [&]() {
    // Take back ownership; the resource is freed when this lambda returns.
    // static_cast is the correct cast for void* -> T*.
    std::unique_ptr<AsyncResource> async_resource{static_cast<AsyncResource*>(user_data)};
    Ort::Status status(ort_status);  // adopts (and releases) ort_status

    // On error, report the message to python and skip output conversion.
    if (!status.IsOK()) {
      async_resource->callback({}, async_resource->user_data, status.GetErrorMessage());
      return;
    }

    std::vector<py::object> rfetch;
    rfetch.reserve(num_outputs);
    for (size_t ith = 0; ith < num_outputs; ++ith) {
      const auto& fet = *outputs[ith];
      if (fet.IsAllocated()) {
        if (fet.IsTensor()) {
          rfetch.push_back(AddTensorAsPyObj(fet, nullptr, nullptr));
        } else if (fet.IsSparseTensor()) {
          // ith doubles as the output position for sparse-tensor conversion.
          rfetch.push_back(GetPyObjectFromSparseTensor(ith, fet, nullptr));
        } else {
          rfetch.push_back(AddNonTensorAsPyObj(fet, nullptr, nullptr));
        }
      } else {
        // Unallocated output -> None, keeping result positions aligned.
        rfetch.push_back(py::none());
      }
    }
    async_resource->callback(rfetch, async_resource->user_data, "");
  };

  if (PyGILState_Check()) {
    invoke_callback();
  } else {
    // acquire GIL to safely:
    // 1) invoke python callback
    // 2) create, manipulate, and destroy python objects
    py::gil_scoped_acquire acquire;
    invoke_callback();
  }
}

template <typename T>
static py::object AddNonTensor(const OrtValue& val,
const DataTransferManager* /*data_transfer_manager*/,
Expand Down Expand Up @@ -1680,6 +1759,53 @@ including arg name, arg type (contains both type and shape).)pbdoc")
}
return rfetch;
})
.def("run_async",
     // Submits an asynchronous run; results are delivered to `callback`
     // (via AsyncCallback) on an ort intra-op threadpool thread.
     [](PyInferenceSession* sess,
        std::vector<std::string> output_names,
        std::map<std::string, py::object> pyfeeds,
        PyCallback callback, py::object user_data = {},
        RunOptions* run_options = nullptr)
         -> void {
       // AsyncResource keeps names/values alive for the duration of the async
       // run; ownership transfers to AsyncCallback on successful submission.
       std::unique_ptr<AsyncResource> async_resource = std::make_unique<AsyncResource>();
       async_resource->callback = std::move(callback);
       async_resource->user_data = user_data;
       // prepare feeds
       async_resource->ReserveFeeds(pyfeeds.size());
       // Model input defs are loop-invariant — fetch and validate them once.
       auto px = sess->GetSessionHandle()->GetModelInputs();
       if (!px.first.IsOK() || !px.second) {
         throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
       }
       for (const auto& feed : pyfeeds) {
         if (!feed.second.is(py::none())) {  // None feeds are silently skipped
           OrtValue ml_value;
           CreateGenericMLValue(px.second, GetAllocator(), feed.first, feed.second, &ml_value);
           ThrowIfPyErrOccured();
           async_resource->feeds.push_back(ml_value);
           // Safe: ReserveFeeds guarantees no reallocation, so back() pointers stay valid.
           async_resource->feeds_raw.push_back(&async_resource->feeds.back());
           async_resource->feed_names.push_back(feed.first);
           async_resource->feed_names_raw.push_back(async_resource->feed_names.back().c_str());
         }
       }
       // prepare fetches
       async_resource->ReserveFetches(output_names.size());
       for (const auto& output_name : output_names) {
         async_resource->fetch_names.push_back(output_name);
         async_resource->fetch_names_raw.push_back(async_resource->fetch_names.back().c_str());
         async_resource->fetches_raw.push_back({});  // null slot: runtime allocates the output
       }
       const RunOptions* run_async_option = run_options ? run_options : &async_resource->default_run_option;
       common::Status status = sess->GetSessionHandle()->RunAsync(run_async_option,
                                                                  gsl::span(async_resource->feed_names_raw.data(), async_resource->feed_names_raw.size()),
                                                                  gsl::span(async_resource->feeds_raw.data(), async_resource->feeds_raw.size()),
                                                                  gsl::span(async_resource->fetch_names_raw.data(), async_resource->fetch_names_raw.size()),
                                                                  gsl::span(async_resource->fetches_raw.data(), async_resource->fetches_raw.size()),
                                                                  AsyncCallback,
                                                                  async_resource.get());
       if (status.IsOK()) {
         // Submission accepted: AsyncCallback now owns (and will free) the resource.
         async_resource.release();
       }
       OrtPybindThrowIfError(status);
     })
/// This method accepts a dictionary of feeds (name -> OrtValue) and the list of output_names
/// and returns a list of python objects representing OrtValues. Each name may represent either
/// a Tensor, SparseTensor or a TensorSequence.
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,37 @@ def test_run_model(self):
output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)

def test_run_async(self):
    """run_async must deliver the results, the user_data object, and an empty error string."""
    event = threading.Event()
    output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)

    class MyData:
        def __init__(self, id):
            self.__id = id

        def get_id(self):
            return self.__id

    my_data = MyData(123456)

    # The callback runs on a cxx (intra-op threadpool) thread, where a raised
    # AssertionError would be lost instead of failing the test. Record the
    # outcome here and perform all assertions on the main thread below.
    outcome = {}

    def callback(res: np.ndarray, data: MyData, err: str) -> None:
        outcome["res"] = res
        outcome["data"] = data
        outcome["err"] = err
        event.set()

    so = onnxrt.SessionOptions()
    so.intra_op_num_threads = 2

    sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), so, providers=available_providers)

    x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
    sess.run_async(["Y"], {"X": x}, callback, my_data)

    # Event.wait returns False on timeout (10 sec), which fails the test directly.
    self.assertTrue(event.wait(10))
    self.assertEqual(len(outcome["err"]), 0)
    self.assertEqual(len(outcome["res"]), 1)
    self.assertEqual(outcome["data"].get_id(), 123456)
    np.testing.assert_allclose(output_expected, outcome["res"][0], rtol=1e-05, atol=1e-08)

def test_run_model_from_bytes(self):
with open(get_name("mul_1.onnx"), "rb") as f:
content = f.read()
Expand Down

0 comments on commit c392fde

Please sign in to comment.