[TensorRT] Enable refitting an embedded engine when provided as byte stream #21357

Merged (8 commits) on Jul 20, 2024
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -16,6 +16,7 @@ struct OrtTensorRTProviderOptionsV2 {
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
// can be updated using: UpdateTensorRTProviderOptionsWithValue
int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT.
@@ -78,6 +79,12 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
// the ONNX model containing the weights (applicable only when
// the "trt_weight_stripped_engine_enable" option is enabled)
const void* trt_onnx_bytestream{nullptr}; // The byte stream of the original ONNX model containing the weights
// (applicable only when the "trt_weight_stripped_engine_enable"
// option is enabled)
// can be updated using: UpdateTensorRTProviderOptionsWithValue
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
// can be updated using: UpdateTensorRTProviderOptionsWithValue

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
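Since the two new options carry a raw pointer and a byte count, they cannot be set through the string-based options API; as the comments above note, they go through UpdateTensorRTProviderOptionsWithValue. Below is a minimal sketch of wiring this up via the C API. The key strings mirror the struct field names, passing the size as a pointer to size_t is an assumption based on this PR's pattern, "model.onnx" is a hypothetical filename, and status checks are omitted for brevity.

#include <onnxruntime_c_api.h>
#include <fstream>
#include <iterator>
#include <vector>

int main() {
  // Load the original (weight-full) ONNX model into memory; in a real
  // deployment the bytes might come from a download or a packed resource.
  std::ifstream f("model.onnx", std::ios::binary);  // hypothetical filename
  std::vector<char> model_bytes((std::istreambuf_iterator<char>(f)),
                                std::istreambuf_iterator<char>());
  size_t model_size = model_bytes.size();

  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);

  // String-valued options use the key/value-array updater.
  const char* keys[] = {"trt_weight_stripped_engine_enable"};
  const char* values[] = {"1"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 1);

  // Pointer- and size-valued options use the ...WithValue variant.
  api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream",
                                              model_bytes.data());
  api->UpdateTensorRTProviderOptionsWithValue(trt_options, "trt_onnx_bytestream_size",
                                              &model_size);

  // ... append to session options and create the session, then release:
  api->ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}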
24 changes: 23 additions & 1 deletion onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -274,6 +274,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
// Only make path checks if the model is not provided as a byte buffer
bool make_secure_path_checks = !GetModelPath(graph_viewer).empty();

if (embed_mode) {
// Get engine from byte stream.
const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
@@ -284,6 +287,23 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from binary data");
}

if (weight_stripped_engine_refit_) {
const std::string onnx_model_filename = attrs.at(ONNX_MODEL_FILENAME).s();
std::string placeholder;
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
onnx_model_folder_path_,
placeholder,
make_secure_path_checks,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
(*trt_engine_).get(),
false /* serialize refitted engine to disk */,
detailed_build_log_);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
}
} else {
// Get engine from cache file.
std::string cache_path = attrs.at(EP_CACHE_CONTEXT).s();
@@ -343,7 +363,9 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
auto status = TensorrtExecutionProvider::RefitEngine(onnx_model_filename,
onnx_model_folder_path_,
weight_stripped_engine_cache,
true /* path check for security */,
make_secure_path_checks,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
(*trt_engine_).get(),
true /* serialize refitted engine to disk */,
detailed_build_log_);
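With this change, an EP-context model whose weight-stripped engine is embedded (embed_mode = 1) can be refitted at load time entirely from memory. A hedged sketch of the corresponding session setup via the C++ API, reusing trt_options as configured in the sketch above; "model_ctx.onnx" stands in for an EP-context model with an embedded weight-stripped engine.

#include <onnxruntime_cxx_api.h>

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_refit");
Ort::SessionOptions session_options;

// trt_options carries trt_onnx_bytestream / trt_onnx_bytestream_size, so
// GetEpContextFromGraph can refit the embedded engine without a model path.
session_options.AppendExecutionProvider_TensorRT_V2(*trt_options);

// The EP-context model contains the serialized weight-stripped engine.
Ort::Session session(env, ORT_TSTR("model_ctx.onnx"), session_options);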
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -52,13 +52,17 @@ class TensorRTCacheModelHandler {
std::string compute_capability,
bool weight_stripped_engine_refit,
std::string onnx_model_folder_path,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
bool detailed_build_log)
: trt_engine_(trt_engine),
trt_runtime_(trt_runtime),
ep_context_model_path_(ep_context_model_path),
compute_capability_(compute_capability),
weight_stripped_engine_refit_(weight_stripped_engine_refit),
onnx_model_folder_path_(onnx_model_folder_path),
onnx_model_bytestream_(onnx_model_bytestream),
onnx_model_bytestream_size_(onnx_model_bytestream_size),
detailed_build_log_(detailed_build_log) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
@@ -74,6 +78,8 @@ class TensorRTCacheModelHandler {
std::string compute_capability_;
bool weight_stripped_engine_refit_;
std::string onnx_model_folder_path_;
const void* onnx_model_bytestream_;
size_t onnx_model_bytestream_size_;
bool detailed_build_log_;
}; // TRTCacheModelHandler
} // namespace onnxruntime
81 changes: 61 additions & 20 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1312,6 +1312,14 @@
engine_cache_enable_ = info.engine_cache_enable;
weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
onnx_model_folder_path_ = info.onnx_model_folder_path;
onnx_model_bytestream_ = info.onnx_bytestream;
onnx_model_bytestream_size_ = info.onnx_bytestream_size;
if ((onnx_model_bytestream_ != nullptr && onnx_model_bytestream_size_ == 0) ||
(onnx_model_bytestream_ == nullptr && onnx_model_bytestream_size_ != 0)) {
ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"When providing either 'trt_onnx_bytestream_size' or "
"'trt_onnx_bytestream' both have to be provided"));
}
timing_cache_enable_ = info.timing_cache_enable;
force_timing_cache_match_ = info.force_timing_cache;
detailed_build_log_ = info.detailed_build_log;
@@ -1736,7 +1744,8 @@
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_;
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2569,38 +2578,61 @@
std::string& onnx_model_folder_path,
std::string& weight_stripped_engine_cath_path,
bool path_check,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
nvinfer1::ICudaEngine* trt_engine,
bool serialize_refitted_engine,
bool detailed_build_log) {
#if NV_TENSORRT_MAJOR >= 10
bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0;
std::filesystem::path onnx_model_path{onnx_model_folder_path};
onnx_model_path.append(onnx_model_filename);
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"For security purpose, the ONNX model path should be set with "
"a relative path, but it is an absolute path: " +
onnx_model_path.string());
}
if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model path has '..'. For security purpose, it's not "
"allowed to point outside the directory.");
}
if (refit_from_file) {
if (!onnx_model_filename.empty()) {
onnx_model_path.append(onnx_model_filename);
}
if (onnx_model_path.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model was not provided as a path. "
"Please provide an ONNX bytestream to enable refitting the weight-stripped engine.");
} else {
// check if file path to ONNX is legal
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"For security purpose, the ONNX model path should be set with "
"a relative path, but it is an absolute path: " +
onnx_model_path.string());
}
if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model path has '..'. For security purpose, it's not "
"allowed to point outside the directory.");
}

if (!std::filesystem::exists(onnx_model_path)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model " + onnx_model_path.string() +
" does not exist.");
if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model " + onnx_model_path.string() +
" does not exist.");
}
}
}

// weight-stripped engine refit logic
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log);
auto refitter = std::unique_ptr<nvinfer1::IRefitter>(nvinfer1::createInferRefitter(*trt_engine, trt_logger));
auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
nvonnxparser::createParserRefitter(*refitter, trt_logger));
if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
if (refit_from_file) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string();
if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string());
}
} else {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array";
if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem");
}
}
if (refitter->refitCudaEngine()) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Successfully refitted the weight-stripped engine.";
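For reference, the byte-stream branch above reduces to the standard TensorRT 10 refit sequence. A self-contained sketch follows; the function name and parameters are illustrative, not the EP's internals.

#include <NvInfer.h>
#include <NvOnnxParser.h>
#include <memory>

// Sketch: refit a weight-stripped engine from an in-memory ONNX model,
// mirroring the refitFromBytes branch above. Requires TensorRT >= 10.
bool RefitEngineFromBytes(nvinfer1::ICudaEngine& engine,
                          nvinfer1::ILogger& logger,
                          const void* onnx_bytes, size_t onnx_size) {
  auto refitter = std::unique_ptr<nvinfer1::IRefitter>(
      nvinfer1::createInferRefitter(engine, logger));
  auto parser_refitter = std::unique_ptr<nvonnxparser::IParserRefitter>(
      nvonnxparser::createParserRefitter(*refitter, logger));
  // Pull the weights out of the in-memory ONNX model instead of a file.
  if (!parser_refitter->refitFromBytes(onnx_bytes, onnx_size)) {
    return false;
  }
  // Transfer the new weights into the deserialized engine.
  return refitter->refitCudaEngine();
}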
@@ -3177,10 +3209,15 @@
}

if (weight_stripped_engine_refit_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
char* onnx = string_buf.data();
size_t onnx_size = string_buf.size();
auto status = RefitEngine(model_path_,
onnx_model_folder_path_,
engine_cache_path,
false /* path check for security */,
onnx,
onnx_size,
trt_engine.get(),
true /* serialize refitted engine to disk */,
detailed_build_log_);
@@ -3636,6 +3673,8 @@
onnx_model_folder_path_,
engine_cache_path,
false /* path check for security */,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
trt_engine,
true /* serialize refitted engine to disk */,
detailed_build_log_);
@@ -3854,6 +3893,8 @@
compute_capability_,
weight_stripped_engine_enable_,
onnx_model_folder_path_,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
detailed_build_log_);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
if (status != Status::OK()) {
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -274,13 +274,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;

/**
* Refit the weight-stripped engine
*/
static common::Status RefitEngine(std::string onnx_model_filename,
std::string& onnx_model_folder_path,
std::string& weight_stripped_engine_cath_path,
bool path_check,
const void* onnx_model_bytestream,
size_t onnx_model_bytestream_size,
nvinfer1::ICudaEngine* trt_engine,
bool serialize_refitted_engine,
bool detailed_build_log);
@@ -305,6 +304,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool weight_stripped_engine_enable_ = false;
bool weight_stripped_engine_refit_ = false;
std::string onnx_model_folder_path_;
const void* onnx_model_bytestream_;
size_t onnx_model_bytestream_size_;
bool build_heuristics_enable_ = false;
bool sparsity_enable_ = false;
int builder_optimization_level_ = 3;
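A hedged illustration of calling the widened static helper when the weights live only in memory, mirroring the embed-mode call in onnx_ctx_model_helper.cc above; onnx_bytes, onnx_size, trt_engine, and "model.onnx" are stand-ins for this example.

// With a byte stream there is no folder to resolve and no cache file to
// rewrite, so the path arguments are placeholders and path checks are off.
std::string folder_path;        // unused when refitting from bytes
std::string cache_placeholder;  // nothing is serialized back to disk
auto status = TensorrtExecutionProvider::RefitEngine(
    "model.onnx",               // filename recorded in the EP-context node
    folder_path,
    cache_placeholder,
    false /* path check for security */,
    onnx_bytes,
    onnx_size,
    trt_engine.get(),
    false /* serialize refitted engine to disk */,
    false /* detailed build log */);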