Support onnx data types (bfloat16, float8) in python I/O binding APIs (#22306)

### Description
(1) Support onnx data types in the following Python APIs:
* IOBinding.bind_input
* IOBinding.bind_output
* ortvalue_from_shape_and_type

(2) Add unit tests, which serve as examples of running BFloat16 or
Float8 models in Python.

Other minor changes:
(3) Replace the deprecated NP_TYPE_TO_TENSOR_TYPE with a helper API.
(4) Rename ortvalue_from_numpy_with_onnxtype to
ortvalue_from_numpy_with_onnx_type.

The integer values for ONNX element types are listed at
https://onnx.ai/onnx/api/mapping.html. Note that FLOAT4E2M1 is not
supported yet.
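
As a hedged illustration of the new surface (the model path, input/output names, and shapes below are hypothetical; the unit tests in this PR are the authoritative examples):

```python
import onnx
import onnxruntime as ort

# Hypothetical BFloat16 model with input "X" and output "Y" of shape [1, 4].
sess = ort.InferenceSession("model_bf16.onnx", providers=["CUDAExecutionProvider"])

# numpy has no bfloat16 dtype, so the element type is passed as the integer
# ONNX enum value instead of a numpy type.
x = ort.OrtValue.ortvalue_from_shape_and_type([1, 4], onnx.TensorProto.BFLOAT16, "cuda", 0)
y = ort.OrtValue.ortvalue_from_shape_and_type([1, 4], onnx.TensorProto.BFLOAT16, "cuda", 0)

io_binding = sess.io_binding()
io_binding.bind_input("X", "cuda", 0, onnx.TensorProto.BFLOAT16, [1, 4], x.data_ptr())
io_binding.bind_output("Y", "cuda", 0, onnx.TensorProto.BFLOAT16, [1, 4], y.data_ptr())
sess.run_with_iobinding(io_binding)
```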

### Motivation and Context

The current Python API does not support BFloat16 and Float8 (FLOAT8E4M3FN,
FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ) types, nor other new data
types such as INT4 and UINT4.

This change removes that limitation.

#13001
#20481
#20578
tianleiwu authored Oct 5, 2024
1 parent 96a1ce1 commit b5ef855
Showing 7 changed files with 249 additions and 78 deletions.
32 changes: 22 additions & 10 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -602,7 +602,7 @@ def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr):
:param name: input name
:param device_type: e.g. cpu, cuda, cann
:param device_id: device id, e.g. 0
:param element_type: input element type
:param element_type: input element type. It can be either a numpy type (like numpy.float32) or an integer for an onnx type (like onnx.TensorProto.BFLOAT16)
:param shape: input shape
:param buffer_ptr: memory pointer to input data
"""
@@ -641,7 +641,7 @@ def bind_output(
:param name: output name
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
:param element_type: output element type
:param element_type: output element type. It can be either a numpy type (like numpy.float32) or an integer for an onnx type (like onnx.TensorProto.BFLOAT16)
:param shape: output shape
:param buffer_ptr: memory pointer to output data
"""
@@ -758,31 +758,43 @@ def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0):
)

@staticmethod
def ortvalue_from_numpy_with_onnxtype(data: Sequence[int], onnx_element_type: int):
def ortvalue_from_numpy_with_onnx_type(data, onnx_element_type: int):
"""
This method creates an instance of OrtValue on top of the numpy array
This method creates an instance of OrtValue on top of the numpy array.
No data copy is made and the lifespan of the resulting OrtValue should never
exceed the lifespan of the underlying numpy array. The API attempts to reinterpret
the data type which is expected to be the same size. This is useful
when we want to use an ONNX data type that is not supported by numpy.
:param data: numpy array.
:param data: numpy.ndarray.
:param onnx_element_type: a valid onnx TensorProto::DataType enum value
"""
return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnxtype(data, onnx_element_type), data)
return OrtValue(C.OrtValue.ortvalue_from_numpy_with_onnx_type(data, onnx_element_type), data)
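
For example (a sketch; the bit pattern is illustrative), a numpy.uint16 buffer can carry raw BFloat16 bits, since both types are two bytes wide:

```python
import numpy as np
import onnx
import onnxruntime as ort

# 1.0 and 0.0 in bfloat16, stored as raw bits in a same-sized numpy dtype.
bits = np.array([0x3F80, 0x0000], dtype=np.uint16)
value = ort.OrtValue.ortvalue_from_numpy_with_onnx_type(bits, onnx.TensorProto.BFLOAT16)
assert value.element_type() == onnx.TensorProto.BFLOAT16
```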

@staticmethod
def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type="cpu", device_id=0):
def ortvalue_from_shape_and_type(shape, element_type, device_type: str = "cpu", device_id: int = 0):
"""
Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type
:param shape: List of integers indicating the shape of the OrtValue
:param element_type: The data type of the elements in the OrtValue (numpy type)
:param element_type: The data type of the elements. It can be either a numpy type (like numpy.float32) or an integer for an onnx type (like onnx.TensorProto.BFLOAT16).
:param device_type: e.g. cpu, cuda, cann, cpu by default
:param device_id: device id, e.g. 0
"""
if shape is None or element_type is None:
raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided")
# Integer for onnx element type (see https://onnx.ai/onnx/api/mapping.html).
# This is helpful for data types (like TensorProto.BFLOAT16) that are not available in numpy.
if isinstance(element_type, int):
return OrtValue(
C.OrtValue.ortvalue_from_shape_and_onnx_type(
shape,
element_type,
C.OrtDevice(
get_ort_device_type(device_type, device_id),
C.OrtDevice.default_memory(),
device_id,
),
)
)

return OrtValue(
C.OrtValue.ortvalue_from_shape_and_type(
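Both dispatch paths of the updated factory method can be exercised as follows (a sketch; shape and device are illustrative):

```python
import numpy as np
import onnx
import onnxruntime as ort

# numpy dtype path (existing behavior):
v1 = ort.OrtValue.ortvalue_from_shape_and_type([2, 3], np.float32, "cpu", 0)
# integer ONNX enum path (new in this change), for types numpy cannot express:
v2 = ort.OrtValue.ortvalue_from_shape_and_type([2, 3], onnx.TensorProto.BFLOAT16, "cpu", 0)
```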
80 changes: 52 additions & 28 deletions onnxruntime/python/onnxruntime_pybind_iobinding.cc
@@ -20,6 +20,39 @@ namespace python {

namespace py = pybind11;

namespace {
void BindOutput(SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device,
MLDataType element_type, const std::vector<int64_t>& shape, int64_t data_ptr) {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid");
InferenceSession* sess = io_binding->GetInferenceSession();
auto px = sess->GetModelOutputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}

// For now, limit binding support to only non-string Tensors
const auto& def_list = *px.second;
onnx::TypeProto type_proto;
if (!CheckIfTensor(def_list, name, type_proto)) {
throw std::runtime_error("Only binding Tensors is currently supported");
}

ORT_ENFORCE(utils::HasTensorType(type_proto) && utils::HasElemType(type_proto.tensor_type()));
if (type_proto.tensor_type().elem_type() == onnx::TensorProto::STRING) {
throw std::runtime_error("Only binding non-string Tensors is currently supported");
}

OrtValue ml_value;
OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
Tensor::InitOrtValue(element_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

auto status = io_binding->Get()->BindOutput(name, ml_value);
if (!status.IsOK()) {
throw std::runtime_error("Error when binding output: " + status.ErrorMessage());
}
}
} // namespace

void addIoBindingMethods(pybind11::module& m) {
py::class_<SessionIOBinding> session_io_binding(m, "SessionIOBinding");
session_io_binding
@@ -58,6 +91,18 @@ void addIoBindingMethods(pybind11::module& m) {
}
})
// This binds input as a Tensor that wraps a memory pointer along with the OrtMemoryInfo. The element type is an integer onnx data type.
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
auto ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type);
OrtValue ml_value;
OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

auto status = io_binding->Get()->BindInput(name, ml_value);
if (!status.IsOK()) {
throw std::runtime_error("Error when binding input: " + status.ErrorMessage());
}
})
// This binds input as a Tensor that wraps a memory pointer along with the OrtMemoryInfo. The element type is a numpy type.
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
@@ -90,44 +135,23 @@ void addIoBindingMethods(pybind11::module& m) {
throw std::runtime_error("Error when synchronizing bound inputs: " + status.ErrorMessage());
}
})
// This binds output to pre-allocated memory as a Tensor.
// The element type is an integer onnx data type, i.e. a key of onnx.mapping.TENSOR_TYPE_MAP (https://onnx.ai/onnx/api/mapping.html).
.def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, int32_t element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
MLDataType ml_type = OnnxTypeToOnnxRuntimeTensorType(element_type);
BindOutput(io_binding, name, device, ml_type, shape, data_ptr);
})
// This binds output to pre-allocated memory as a Tensor
.def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid");

InferenceSession* sess = io_binding->GetInferenceSession();
auto px = sess->GetModelOutputs();
if (!px.first.IsOK() || !px.second) {
throw std::runtime_error("Either failed to get model inputs from the session object or the input def list was null");
}

// For now, limit binding support to only non-string Tensors
const auto& def_list = *px.second;
onnx::TypeProto type_proto;
if (!CheckIfTensor(def_list, name, type_proto)) {
throw std::runtime_error("Only binding Tensors is currently supported");
}

ORT_ENFORCE(utils::HasTensorType(type_proto) && utils::HasElemType(type_proto.tensor_type()));
if (type_proto.tensor_type().elem_type() == onnx::TensorProto::STRING) {
throw std::runtime_error("Only binding non-string Tensors is currently supported");
}

PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
throw std::runtime_error("Not a valid numpy type");
}
int type_num = dtype->type_num;
Py_DECREF(dtype);

OrtMemoryInfo info(GetDeviceName(device), OrtDeviceAllocator, device, device.Id());
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
OrtValue ml_value;
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), reinterpret_cast<void*>(data_ptr), info, ml_value);

auto status = io_binding->Get()->BindOutput(name, ml_value);
if (!status.IsOK()) {
throw std::runtime_error("Error when binding output: " + status.ErrorMessage());
}
BindOutput(io_binding, name, device, ml_type, shape, data_ptr);
})
// This binds output to a device. Meaning that the output OrtValue must be allocated on a specific device.
.def("bind_output", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device) -> void {
4 changes: 4 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -467,6 +467,10 @@ MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type) {
}
}

MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type) {
return DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_type)->GetElementType();
}

// This is a one time use, ad-hoc allocator that allows Tensors to take ownership of
// python array objects and use the underlying memory directly and
// properly deallocate them when they are done.
2 changes: 2 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -40,6 +40,8 @@ int OnnxRuntimeTensorToNumpyType(const DataTypeImpl* tensor_type);

MLDataType NumpyTypeToOnnxRuntimeTensorType(int numpy_type);

MLDataType OnnxTypeToOnnxRuntimeTensorType(int onnx_element_type);

using MemCpyFunc = void (*)(void*, const void*, size_t);

using DataTransferAlternative = std::variant<const DataTransferManager*, MemCpyFunc>;
85 changes: 51 additions & 34 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -21,6 +21,42 @@ namespace python {

namespace py = pybind11;

namespace {
std::unique_ptr<OrtValue> OrtValueFromShapeAndType(const std::vector<int64_t>& shape,
MLDataType element_type,
const OrtDevice& device) {
AllocatorPtr allocator;
if (strcmp(GetDeviceName(device), CPU) == 0) {
allocator = GetAllocator();
} else if (strcmp(GetDeviceName(device), CUDA) == 0) {
#ifdef USE_CUDA
if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
allocator = GetCudaAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (strcmp(GetDeviceName(device), DML) == 0) {
#if USE_DML
allocator = GetDmlAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
"Please use the DirectML package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
}

auto ml_value = std::make_unique<OrtValue>();
Tensor::InitOrtValue(element_type, gsl::make_span(shape), std::move(allocator), *ml_value);
return ml_value;
}
} // namespace

void addOrtValueMethods(pybind11::module& m) {
py::class_<OrtValue> ortvalue_binding(m, "OrtValue");
ortvalue_binding
@@ -144,13 +180,12 @@ void addOrtValueMethods(pybind11::module& m) {
})
// Create an OrtValue on top of the numpy array, but interpret the data
// as a different type with the same element size.
.def_static("ortvalue_from_numpy_with_onnxtype", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr<OrtValue> {
.def_static("ortvalue_from_numpy_with_onnx_type", [](py::array& data, int32_t onnx_element_type) -> std::unique_ptr<OrtValue> {
if (!ONNX_NAMESPACE::TensorProto_DataType_IsValid(onnx_element_type)) {
ORT_THROW("Not a valid ONNX Tensor data type: ", onnx_element_type);
}

const auto element_type = DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_type)
->GetElementType();
const auto element_type = OnnxTypeToOnnxRuntimeTensorType(onnx_element_type);

const auto element_size = element_type->Size();
if (narrow<size_t>(data.itemsize()) != element_size) {
@@ -164,11 +199,11 @@ void addOrtValueMethods(pybind11::module& m) {
const_cast<void*>(data.data()), cpu_allocator->Info(), *ort_value);
return ort_value;
})
// Factory method to create an OrtValue (Tensor) from the given shape and element type with memory on the specified device
// Factory method to create an OrtValue from the given shape and numpy element type on the specified device.
// The memory is left uninitialized
.def_static("ortvalue_from_shape_and_type", [](const std::vector<int64_t>& shape, py::object& element_type, const OrtDevice& device) {
.def_static("ortvalue_from_shape_and_type", [](const std::vector<int64_t>& shape, py::object& numpy_element_type, const OrtDevice& device) -> std::unique_ptr<OrtValue> {
PyArray_Descr* dtype;
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
if (!PyArray_DescrConverter(numpy_element_type.ptr(), &dtype)) {
throw std::runtime_error("Not a valid numpy type");
}

Expand All @@ -179,36 +214,18 @@ void addOrtValueMethods(pybind11::module& m) {
throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
}

AllocatorPtr allocator;
if (strcmp(GetDeviceName(device), CPU) == 0) {
allocator = GetAllocator();
} else if (strcmp(GetDeviceName(device), CUDA) == 0) {
#ifdef USE_CUDA
if (!IsCudaDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) {
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine.");
}
allocator = GetCudaAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
"Please use the CUDA package of OnnxRuntime to use this feature.");
#endif
} else if (strcmp(GetDeviceName(device), DML) == 0) {
#if USE_DML
allocator = GetDmlAllocator(device.Id());
#else
throw std::runtime_error(
"Can't allocate memory on the DirectML device using this package of OnnxRuntime. "
"Please use the DirectML package of OnnxRuntime to use this feature.");
#endif
} else {
throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
}
auto element_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
return OrtValueFromShapeAndType(shape, element_type, device);
})
// Factory method to create an OrtValue from the given shape and onnx element type on the specified device.
// The memory is left uninitialized
.def_static("ortvalue_from_shape_and_onnx_type", [](const std::vector<int64_t>& shape, int32_t onnx_element_type, const OrtDevice& device) -> std::unique_ptr<OrtValue> {
if (onnx_element_type == onnx::TensorProto_DataType::TensorProto_DataType_STRING) {
throw std::runtime_error("Creation of OrtValues is currently only supported from non-string numpy arrays");
}

auto ml_value = std::make_unique<OrtValue>();
auto ml_type = NumpyTypeToOnnxRuntimeTensorType(type_num);
Tensor::InitOrtValue(ml_type, gsl::make_span(shape), std::move(allocator), *ml_value);
return ml_value;
auto element_type = OnnxTypeToOnnxRuntimeTensorType(onnx_element_type);
return OrtValueFromShapeAndType(shape, element_type, device);
})

#if !defined(DISABLE_SPARSE_TENSORS)
8 changes: 5 additions & 3 deletions onnxruntime/test/python/onnxruntime_test_python.py
@@ -1391,7 +1391,9 @@ def test_session_with_ortvalue_input(ortvalue):

# test ort_value creation on top of the bytes
float_tensor_data_type = 1 # TensorProto_DataType_FLOAT
ort_value_with_type = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(numpy_arr_input, float_tensor_data_type)
ort_value_with_type = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type(
numpy_arr_input, float_tensor_data_type
)
self.assertTrue(ort_value_with_type.is_tensor())
self.assertEqual(float_tensor_data_type, ort_value_with_type.element_type())
self.assertEqual([3, 2], ort_value_with_type.shape())
@@ -1843,8 +1845,8 @@ def test_adater_export_read(self):
param_1 = np.array(val).astype(np.float32).reshape(5, 2)
param_2 = np.array(val).astype(np.int64).reshape(2, 5)

ort_val_1 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_1, float_data_type)
ort_val_2 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnxtype(param_2, int64_data_type)
ort_val_1 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type(param_1, float_data_type)
ort_val_2 = onnxrt.OrtValue.ortvalue_from_numpy_with_onnx_type(param_2, int64_data_type)

params = {"param_1": ort_val_1, "param_2": ort_val_2}

