-
Notifications
You must be signed in to change notification settings - Fork 335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom Op runtime wrapper example #150
Draft
adrianlizarraga
wants to merge
7
commits into
main
Choose a base branch
from
adrianl/opwrapper_ep
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
24c6b70
Initial OpWrapper EP example
adrianlizarraga e091281
Print output probabilities
adrianlizarraga db22dd2
Parse command-line args
adrianlizarraga 0041618
Feed in actual image file
adrianlizarraga 4fad275
Update example to use re-implementation of custom op wrappers (no EP)
adrianlizarraga a270033
Use custom_op helper that adds session configs
adrianlizarraga 6073f0d
Use simpler mnist model
adrianlizarraga File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

cmake_minimum_required(VERSION 3.13)

project(opwrapper_ov_test)

# Root of the ONNX Runtime installation.
# option() only supports booleans, so a cache PATH variable is used instead.
set(ONNXRUNTIME_ROOTDIR "" CACHE PATH "onnxruntime root directory")

if(NOT ONNXRUNTIME_ROOTDIR)
  if(WIN32)
    set(ONNXRUNTIME_ROOTDIR "C:/Program Files (x86)/onnxruntime")
  else()
    # Fix: the original only added /usr/local/include/onnxruntime to the
    # include path but left ONNXRUNTIME_ROOTDIR unset, which broke every
    # later use of ${ONNXRUNTIME_ROOTDIR} (header check, include dirs,
    # lib dir). Default to the conventional Unix install prefix.
    set(ONNXRUNTIME_ROOTDIR "/usr/local")
    include_directories("/usr/local/include/onnxruntime")
  endif()
endif()

# opwrapper_provider_factory.h contains EP-specific APIs; detect whether the
# installed ONNX Runtime ships it.
include(CheckIncludeFileCXX)
set(CMAKE_REQUIRED_INCLUDES "${ONNXRUNTIME_ROOTDIR}/include")
CHECK_INCLUDE_FILE_CXX(opwrapper_provider_factory.h HAVE_OPWRAPPER_PROVIDER_FACTORY_H)

include(CMakePrintHelpers)
cmake_print_variables(ONNXRUNTIME_ROOTDIR)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

add_executable(opwrapper_ov_test main.cpp)

# Target-scoped include/link directories instead of directory-scoped
# include_directories()/link_directories(), which leak to every target below.
target_include_directories(opwrapper_ov_test PRIVATE
  "${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/session"
  "${ONNXRUNTIME_ROOTDIR}/include/onnxruntime/core/providers/opwrapper")
target_link_directories(opwrapper_ov_test PRIVATE "${ONNXRUNTIME_ROOTDIR}/lib")

if(HAVE_OPWRAPPER_PROVIDER_FACTORY_H)
  # target_compile_definitions takes the bare macro name; the -D prefix is
  # redundant (CMake strips it anyway).
  target_compile_definitions(opwrapper_ov_test PRIVATE HAVE_OPWRAPPER_PROVIDER_FACTORY_H)
endif()

target_link_libraries(opwrapper_ov_test PRIVATE onnxruntime)

add_subdirectory(custom_op)

if(WIN32)
  # Copy onnxruntime.dll next to the executable so it is found at run time.
  add_custom_command(TARGET opwrapper_ov_test POST_BUILD
    COMMAND ${CMAKE_COMMAND} -E copy_if_different
      "${ONNXRUNTIME_ROOTDIR}/bin/onnxruntime.dll"
      $<TARGET_FILE_DIR:opwrapper_ov_test>
    VERBATIM)
endif()

# TODO: Run a script that 1) downloads the OpenVINO model and
# 2) runs a python program to create the onnx model wrapper.
# Copy the ONNX model to the executable's data directory.
add_custom_command(TARGET opwrapper_ov_test POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E copy_if_different
    "${CMAKE_SOURCE_DIR}/custom_op_ov_ep_wrapper.onnx"
    "${CMAKE_BINARY_DIR}/data/custom_op_ov_ep_wrapper.onnx"
  VERBATIM)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
find_package(OpenVINO REQUIRED)

# MODULE: a plugin loaded at run time with dlopen/LoadLibrary, never linked
# against directly.
add_library(openvino_wrapper MODULE custom_op.cpp openvino_wrapper.cpp)
target_link_libraries(openvino_wrapper PRIVATE openvino::runtime)

# Linker options: export only RegisterCustomOps and strip unused code.
# target_link_options (CMake 3.13+, matching the project's minimum) with the
# LINKER: prefix replaces the legacy LINK_FLAGS string property, which was
# appended via APPEND_STRING without a separating space and could fuse with
# pre-existing flags.
if(APPLE)
  target_link_options(openvino_wrapper PRIVATE "LINKER:-dead_strip")
elseif(UNIX)
  target_link_options(openvino_wrapper PRIVATE
    "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/custom_op.lds"
    "LINKER:--no-undefined"
    "LINKER:--gc-sections"
    "LINKER:-z,noexecstack")
else()
  # Non-UNIX is assumed to be MSVC, as in the original logic.
  target_link_options(openvino_wrapper PRIVATE "-DEF:${CMAKE_CURRENT_SOURCE_DIR}/custom_op.def")
endif()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <mutex>
#include <vector>

#include "custom_op.h"
#include "openvino_wrapper.h"

// Domain under which the custom op is registered.
static const char* c_OpDomain = "ai.onnx.contrib";

// Keeps registered domains alive for the lifetime of the process.
// Ort::CustomOpDomain releases the underlying OrtCustomOpDomain on
// destruction, so domains added to session options must outlive the session;
// the static container guarantees that. Mutex-guarded because this shared
// library may be registered from multiple threads.
static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
  static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
  static std::mutex ort_custom_op_domain_mutex;
  std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
  ort_custom_op_domain_container.push_back(std::move(domain));
}

// Entry point called by ONNX Runtime when this library is loaded via
// RegisterCustomOpsLibrary. Registers the OpenVINO wrapper custom op into
// the given session options. Returns nullptr on success or an OrtStatus
// describing the failure (ownership transfers to the caller).
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {

  // Allow use of Ort::GetApi() and Ort::OpWrapper::GetApi() in C++ ORT api implementations.
  Ort::OpWrapper::InitApi(api_base->GetApi(ORT_API_VERSION));

  // Single op instance shared by all sessions; ORT requires it to stay alive.
  static CustomOpOpenVINO c_CustomOpOpenVINO;

  OrtStatus* result = nullptr;

  try {
    Ort::CustomOpDomain domain{c_OpDomain};
    domain.Add(&c_CustomOpOpenVINO);

    // Wrap the caller-owned options without taking ownership, then park the
    // domain in the static container so it outlives the session.
    Ort::Unowned<Ort::SessionOptions> session_options(options);
    session_options.Add(domain);
    AddOrtCustomOpDomainToContainer(std::move(domain));

  } catch (const Ort::Exception& e) {
    // Preserve the ORT-specific error code.
    result = Ort::GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
  } catch(const std::exception& e) {
    result = Ort::GetApi().CreateStatus(ORT_FAIL, e.what());
  }

  return result;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
; Module-definition file: exports RegisterCustomOps from the custom-op DLL
; so ONNX Runtime can resolve it with GetProcAddress.
; NOTE(review): LIBRARY names "custom_op_library.dll" while the CMake target
; is "openvino_wrapper" — confirm the intended output DLL name.
LIBRARY "custom_op_library.dll"
EXPORTS
    RegisterCustomOps @1

This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <onnxruntime_c_api.h>

// C linkage so ONNX Runtime can locate the symbol by its unmangled name
// when loading this library as a custom-op library.
#ifdef __cplusplus
extern "C" {
#endif

// Registers this library's custom ops into the given session options.
// Returns nullptr on success, or an OrtStatus (owned by the caller) on error.
ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base);

#ifdef __cplusplus
}
#endif
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
/* GNU ld version script: export only RegisterCustomOps from the shared
   library and hide every other symbol (matched by the local wildcard). */
VERS_1.0.0 {
  global:
    RegisterCustomOps;
  local:
    *;
};
227 changes: 227 additions & 0 deletions
227
c_cxx/OpWrapper_EP/OpenVINO/custom_op/openvino_wrapper.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "openvino_wrapper.h"

#include <iostream>
#include <cassert>

// Maps an ONNX tensor element type to the corresponding OpenVINO element
// type. Returns ov::element::undefined for types with no OpenVINO
// counterpart (e.g. string, complex).
static ov::element::Type ConvertONNXToOVType(ONNXTensorElementDataType onnx_type) {
  switch (onnx_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
      return ov::element::f32;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
      return ov::element::u8;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
      return ov::element::i8;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return ov::element::u16;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
      return ov::element::i16;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
      return ov::element::i32;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
      return ov::element::i64;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return ov::element::boolean;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return ov::element::f16;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
      return ov::element::f64;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
      return ov::element::u32;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
      return ov::element::u64;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
      return ov::element::bf16;
    default:
      return ov::element::undefined;
  }
}
|
||
// Returns true iff the ORT shape and the OpenVINO shape describe the same
// tensor extents: identical rank and element-wise equal dimensions.
static bool AreShapesEqual(const std::vector<int64_t>& ort_shape, const ov::Shape& ov_shape) {
  // Ranks must agree before any dimension comparison makes sense.
  if (ort_shape.size() != ov_shape.size()) {
    return false;
  }

  size_t dim = 0;
  for (const auto ov_dim : ov_shape) {
    // Cast the signed ORT dimension to OpenVINO's dimension type, as the
    // original implementation did, before comparing.
    if (ov_dim != static_cast<decltype(ov_dim)>(ort_shape[dim])) {
      return false;
    }
    ++dim;
  }

  return true;
}
|
||
// Compares one ORT graph input/output against one OpenVINO model
// input/output. Returns true only if name, element type, and shape all
// match.
static bool AreIONodesEqual(OrtAllocator* allocator, const Ort::NodeArg& ort_node,
                            const ov::Output<ov::Node>& ov_node) {
  // Check name
  // NOTE(review): strncmp bounded by ort_name.second is effectively a
  // prefix comparison — if the OpenVINO name is longer but shares the
  // first ort_name.second characters, this still reports equality.
  // Confirm ort_name.second covers the full string (incl. terminator),
  // or add an explicit length check.
  auto ort_name = ort_node.GetName(allocator);
  std::string ov_name = ov_node.get_any_name();
  if (std::strncmp(ort_name.first.get(), ov_name.c_str(), ort_name.second) != 0) {
    return false;
  }

  Ort::TypeInfo type_info = ort_node.GetTypeInfo();
  Ort::Unowned<Ort::TensorTypeAndShapeInfo> type_shape_info = type_info.GetTensorTypeAndShapeInfo();

  // Check element type.
  ov::element::Type ort_elem_type = ConvertONNXToOVType(type_shape_info.GetElementType());
  ov::element::Type ov_elem_type = ov_node.get_element_type();
  if (ort_elem_type != ov_elem_type) {
    return false;
  }

  // Check shape.
  std::vector<int64_t> ort_shape = type_shape_info.GetShape();
  const ov::Shape& ov_shape = ov_node.get_shape();
  if (!AreShapesEqual(ort_shape, ov_shape)) {
    return false;
  }

  return true;
}
|
||
// Verifies that the ONNX node's inputs/outputs match the embedded OpenVINO
// model's inputs/outputs positionally: same counts, and per index the same
// name, shape, and element type (via AreIONodesEqual).
static bool ValidateInputsAndOutputs(const Ort::KernelInfo& kinfo, const ov::OutputVector& ov_inputs,
                                     const ov::OutputVector& ov_outputs) {
  const size_t num_inputs = kinfo.GetInputCount();
  const size_t num_outputs = kinfo.GetOutputCount();

  // Number of inputs and outputs must match.
  if (ov_inputs.size() != num_inputs || ov_outputs.size() != num_outputs) {
    return false;
  }

  // Used only to allocate the name strings returned by NodeArg::GetName.
  Ort::AllocatorWithDefaultOptions allocator;

  // Check input names, shapes, and element types.
  for (size_t i = 0; i < num_inputs; ++i) {
    const Ort::NodeArg ort_input = kinfo.GetInput(i);
    const auto& ov_input = ov_inputs[i];

    if (!AreIONodesEqual(static_cast<OrtAllocator*>(allocator), ort_input, ov_input)) {
      return false;
    }
  }

  // Check output names, shapes, and element types.
  for (size_t i = 0; i < num_outputs; ++i) {
    const Ort::NodeArg ort_output = kinfo.GetOutput(i);
    const auto& ov_output = ov_outputs[i];

    if (!AreIONodesEqual(static_cast<OrtAllocator*>(allocator), ort_output, ov_output)) {
      return false;
    }
  }

  return true;
}
|
||
// Builds and compiles an OpenVINO model from the ONNX node's "BIN" (weights)
// and "XML" (topology) string attributes. Throws ORT_INVALID_GRAPH if the
// node's I/O does not match the embedded model's I/O.
KernelOpenVINO::KernelOpenVINO(const OrtApi& api, const OrtKernelInfo* info, const char* op_name) : ort_(api) {
  Ort::KernelInfo kinfo(info);

  // Extract OpenVINO .bin and .xml contents from node attributes.
  this->weights_ = kinfo.GetAttribute<std::string>("BIN");
  std::string xml_contents = kinfo.GetAttribute<std::string>("XML");

  // Create OpenVINO model. The tensor wraps weights_'s buffer rather than
  // copying it, which is why the weights are stored in a member.
  ov::Core core;
  const ov::Shape shape{this->weights_.size()};
  const ov::Tensor weights_tensor(ov::element::u8, shape, weights_.data());
  std::shared_ptr<ov::Model> model = core.read_model(xml_contents, weights_tensor);

  // Validate input/output shapes and types.
  this->ov_inputs_ = model->inputs();
  this->ov_outputs_ = model->outputs();

  if (!ValidateInputsAndOutputs(kinfo, this->ov_inputs_, this->ov_outputs_)) {
    // A more detailed error message would be better.
    ORT_CXX_API_THROW("I/O names, shapes, or element types do not match OpenVINO model.", ORT_INVALID_GRAPH);
  }

  // Get OpenVINO device type from provider options; defaults to "CPU" when
  // no "device_type" option was supplied.
  Ort::OpWrapper::ProviderOptions opts = Ort::OpWrapper::ProviderOptions::FromKernelInfo(info, op_name);
  this->device_type_ = opts.HasOption("device_type") ? opts.GetOption("device_type") : "CPU";

  // Compile OpenVINO model for the selected device.
  this->compiled_model_ = core.compile_model(model, this->device_type_);
}
|
||
// Runs one inference: wraps the ORT input buffers as OpenVINO tensors
// (zero-copy), executes the compiled model, and copies each result into the
// ORT-allocated output buffers.
void KernelOpenVINO::Compute(OrtKernelContext* context) {
  // TODO: Add Ort::KernelContext class.
  const size_t num_inputs = this->ort_.KernelContext_GetInputCount(context);
  assert(num_inputs == this->ov_inputs_.size());

  ov::TensorVector ov_inputs(num_inputs);

  // Gather OpenVINO model inputs. Each ov::Tensor aliases ORT's input
  // buffer; the const_cast is required by the ov::Tensor constructor, and
  // the data is only read during inference.
  for (size_t i = 0; i < num_inputs; ++i) {
    const OrtValue* ort_val = this->ort_.KernelContext_GetInput(context, i);
    const auto& input_info = this->ov_inputs_[i];

    const void* p_input_data = this->ort_.GetTensorData<void>(ort_val);
    ov_inputs[i] = ov::Tensor(input_info.get_element_type(), input_info.get_shape(), const_cast<void*>(p_input_data));
  }

  // Inference.
  // NOTE(review): an infer request is created per Compute call; presumably
  // this keeps the kernel stateless across invocations — confirm whether
  // caching one request per kernel would be safe.
  ov::InferRequest infer_req = this->compiled_model_.create_infer_request();

  infer_req.set_input_tensors(ov_inputs);
  infer_req.infer();

  const size_t num_outputs = this->ort_.KernelContext_GetOutputCount(context);
  assert(num_outputs == this->ov_outputs_.size());

  // Copy inference results to ORT memory.
  for (size_t i = 0; i < num_outputs; ++i) {
    const auto& output_info = this->ov_outputs_[i];

    // Get pointer to output data (src) from OpenVINO inference.
    ov::element::Type elem_type = output_info.get_element_type();
    const void* src = infer_req.get_output_tensor(i).data(elem_type);

    // Get dst to which to copy result; ORT allocates the output buffer for
    // the shape reported by the OpenVINO model.
    const ov::Shape& ov_shape = output_info.get_shape();
    std::vector<int64_t> shape(ov_shape.begin(), ov_shape.end());
    OrtValue* ort_val = this->ort_.KernelContext_GetOutput(context, i, shape.data(), shape.size());
    void* dst = this->ort_.GetTensorMutableData<void>(ort_val);

    // Copy data (element size in bytes times total element count).
    size_t copy_size = elem_type.size() * ov::shape_size(ov_shape);
    std::memcpy(dst, src, copy_size);
  }
}
|
||
//
// CustomOpOpenVINO
// Custom-op schema: a single variadic input and a single variadic output of
// undefined element type, so the op accepts whatever I/O the wrapped
// OpenVINO model declares. Actual validation happens in the kernel ctor.
//

// Creates a kernel instance for one node; ORT owns and frees it.
void* CustomOpOpenVINO::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
  return new KernelOpenVINO(api, info, this->GetName());
}

// Op type name as it appears in the ONNX graph.
const char* CustomOpOpenVINO::GetName() const { return "OpenVINO_EP_Wrapper"; }

// One declared input slot (variadic, see GetInputCharacteristic).
size_t CustomOpOpenVINO::GetInputTypeCount() const { return 1; }

// UNDEFINED lets any tensor element type through.
ONNXTensorElementDataType CustomOpOpenVINO::GetInputType(size_t index) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}

OrtCustomOpInputOutputCharacteristic CustomOpOpenVINO::GetInputCharacteristic(size_t index) const {
  return INPUT_OUTPUT_VARIADIC;
}

// One declared output slot (variadic, see GetOutputCharacteristic).
size_t CustomOpOpenVINO::GetOutputTypeCount() const { return 1; }

ONNXTensorElementDataType CustomOpOpenVINO::GetOutputType(size_t index) const {
  return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
}

OrtCustomOpInputOutputCharacteristic CustomOpOpenVINO::GetOutputCharacteristic(size_t index) const {
  return INPUT_OUTPUT_VARIADIC;
}

// Binds this op to the OpWrapper execution provider.
const char* CustomOpOpenVINO::GetExecutionProviderType() const { return "OpWrapperExecutionProvider"; }
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: remake onnx model to use a different domain name