[Build] Propagate build option for CUDA minimal to TRT (microsoft#20695)
### Description

Extend the CUDA minimal build option to the TensorRT provider, since with TRT 10 linking against cuDNN is no longer required.
Besides that, the new engine dump feature makes it possible to embed an engine into an ONNX model and not ship the builder library.
In addition, this has roughly the same deserialization time/session setup time as using TensorRT standalone.
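
For reference, a minimal sketch of a build invocation that enables this mode (the paths and flag combination are assumptions for illustration; only the `onnxruntime_CUDA_MINIMAL` CMake option comes from this change):

```
# Hypothetical invocation; paths are placeholders.
# The CUDA-minimal behavior is driven by the onnxruntime_CUDA_MINIMAL CMake option.
python tools/ci_build/build.py --config Release --build_shared_lib --parallel \
  --use_tensorrt --tensorrt_home /opt/TensorRT-10 --cuda_home /usr/local/cuda \
  --cmake_extra_defines onnxruntime_CUDA_MINIMAL=ON
```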

### Motivation and Context

```
exe_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model.onnx


exe_no_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model_ctx.onnx
```
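
The first command, run against a build that ships the TensorRT builder library, dumps an EP-context model (model_ctx.onnx) containing the serialized engine; the second command then loads that model with a build that does not ship the builder library, at a session setup cost roughly matching standalone TensorRT engine deserialization.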
gedoensmax authored Jul 9, 2024
1 parent 307b34a commit cc0de0d
Showing 7 changed files with 50 additions and 20 deletions.
10 changes: 6 additions & 4 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -33,19 +33,21 @@
)


if (onnxruntime_CUDA_MINIMAL)
set(onnxruntime_providers_cuda_shared_srcs "")
else()
if (NOT onnxruntime_CUDA_MINIMAL)
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
)
else()
set(onnxruntime_providers_cuda_cu_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu"
)
endif()
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

# disable contrib ops conditionally
if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL)
if (NOT onnxruntime_ENABLE_ATEN)
list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc"
14 changes: 11 additions & 3 deletions cmake/onnxruntime_providers_tensorrt.cmake
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

if(onnxruntime_DISABLE_CONTRIB_OPS)
message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." )
endif()
add_definitions(-DUSE_TENSORRT=1)
if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER)
add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER)
@@ -154,8 +156,11 @@
# See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121
# However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries.
# Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}.
set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})

if(onnxruntime_CUDA_MINIMAL)
set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
else()
set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
endif()
file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc"
@@ -190,6 +195,9 @@
if (WIN32)
target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456)
endif()
if(onnxruntime_CUDA_MINIMAL)
target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE USE_CUDA_MINIMAL=1)
endif()

# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING_OPS)
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -192,6 +192,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const override {
p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
}

#ifndef DISABLE_CONTRIB_OPS
Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override {
@@ -294,12 +299,6 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const override {
p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
}

#ifdef ENABLE_ATEN
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
#endif
8 changes: 3 additions & 5 deletions onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -141,7 +141,9 @@ struct ProviderHostCPU {
virtual Status Scan__Compute(const Scan<9>* p, OpKernelContext* ctx) = 0;
virtual Status Scan__SetupSubgraphExecutionInfo(Scan<8>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
virtual Status Scan__SetupSubgraphExecutionInfo(Scan<9>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const = 0;
#ifndef DISABLE_CONTRIB_OPS
virtual Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) = 0;
virtual Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) = 0;
@@ -203,10 +205,6 @@ struct ProviderHostCPU {
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
gsl::span<const int64_t> input_dims,
InlinedVector<float>& scales) const = 0;

#ifdef ENABLE_ATEN
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
#endif
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -81,6 +81,9 @@ CudaStream::CudaStream(cudaStream_t stream,
cudnn_handle_ = external_cudnn_handle;
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
}
#else
(void)(external_cudnn_handle);
(void)(external_cublas_handle);
#endif
}

16 changes: 16 additions & 0 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -287,6 +287,7 @@ void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const
return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}

#ifndef USE_CUDA_MINIMAL
template <>
Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line) {
return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
@@ -306,6 +307,7 @@ template <>
void CudaCall<cudnnStatus_t, true>(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line) {
return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}
#endif

#if NV_TENSORRT_MAJOR >= 10
void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
@@ -1119,20 +1121,26 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) {
if (has_user_compute_stream) {
CUDA_CALL_THROW(cudaSetDevice(device_id));
#ifndef USE_CUDA_MINIMAL
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream)));
#else
(void)(stream);
#endif
}
}

TensorrtExecutionProvider::PerThreadContext::~PerThreadContext() {
#ifndef USE_CUDA_MINIMAL
if (external_cublas_handle_ != nullptr) {
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
}
if (external_cudnn_handle_ != nullptr) {
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
}
#endif
trt_context_map_.clear();
}

@@ -1268,10 +1276,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (info.has_user_compute_stream) {
external_stream_ = true;
stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
#ifndef USE_CUDA_MINIMAL
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
#endif
}

std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
@@ -1442,6 +1452,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (!ep_context_embed_mode_env.empty()) {
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
}
// in case the EP context is dumped the engine cache has to be enabled
if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) {
engine_cache_enable_ = true;
}

enable_engine_cache_for_ep_context_model();

@@ -1737,8 +1751,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
}

if (external_stream_) {
#ifndef USE_CUDA_MINIMAL
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
#endif
}

if (!external_stream_ && stream_) {
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -3,9 +3,13 @@

#pragma once
#include <ctime>
#ifndef USE_CUDA_MINIMAL
#include <cudnn.h>
#include <cublas_v2.h>

#else
typedef void* cudnnHandle_t;
typedef void* cublasHandle_t;
typedef void* cudnnStatus_t;
#endif
#include "core/providers/tensorrt/nv_includes.h"

#include "core/platform/ort_mutex.h"
