From cc0de0d5262cd1a76532b79ddde4a7d799f840b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com>
Date: Tue, 9 Jul 2024 23:40:04 +0200
Subject: [PATCH] [Build] Propagate build option for CUDA minimal to TRT (#20695)

### Description
Extend the CUDA minimal build option to the TensorRT provider: with TensorRT 10, linking against cuDNN is no longer required. In addition, the new engine dump feature makes it possible to embed an engine into an ONNX model, so the TensorRT builder library no longer has to be shipped. Deserialization/session setup time is roughly the same as when using TensorRT standalone.

### Motivation and Context
The commands below first dump an EP context model with a build that ships the builder library, then load the dumped model with a build that does not:
```
exe_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model.onnx
exe_no_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model_ctx.onnx
```
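For completeness, the sketch below (not part of this patch) shows how the same provider options can be set through the public C++ API instead of `onnxruntime_perf_test`. The model file names mirror the commands above and the option values are illustrative assumptions.
```
// Illustrative only: configure the TensorRT EP with the same options that
// onnxruntime_perf_test receives via -i above. File names are placeholders.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "trt_ep_context"};
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // Same keys/values passed on the perf_test command line.
  const char* keys[] = {"trt_engine_cache_enable", "trt_timing_cache_enable",
                        "trt_dump_ep_context_model", "trt_weightless_engine_enable"};
  const char* values[] = {"1", "1", "1", "1"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(
      trt_options, keys, values, sizeof(keys) / sizeof(keys[0])));

  session_options.AppendExecutionProvider_TensorRT_V2(*trt_options);

  // This run dumps an EP context model (model_ctx.onnx); a later session built with
  // onnxruntime_CUDA_MINIMAL and no builder library can load that model directly.
  Ort::Session session{env, ORT_TSTR("model.onnx"), session_options};

  api.ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}
```
Loading the dumped EP context model skips the TensorRT builder entirely, which is what allows the cuDNN-free, builder-free deployment described above.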
---
 cmake/onnxruntime_providers_cuda.cmake        | 10 ++++++----
 cmake/onnxruntime_providers_tensorrt.cmake    | 14 +++++++++++---
 .../core/providers/cpu/cpu_provider_shared.cc | 11 +++++------
 .../core/providers/cpu/cpu_provider_shared.h  |  8 +++-----
 .../core/providers/cuda/cuda_stream_handle.cc |  3 +++
 .../tensorrt/tensorrt_execution_provider.cc   | 16 ++++++++++++++++
 .../tensorrt/tensorrt_execution_provider.h    |  8 ++++++--
 7 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index 3b48a40bf1166..82c31ce6b6b4d 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -33,19 +33,21 @@
   )

-  if (onnxruntime_CUDA_MINIMAL)
-    set(onnxruntime_providers_cuda_shared_srcs "")
-  else()
+  if (NOT onnxruntime_CUDA_MINIMAL)
     file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
       "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
       "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
     )
+  else()
+    set(onnxruntime_providers_cuda_cu_srcs
+      "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu"
+    )
   endif()

   source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
   set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

   # disable contrib ops conditionally
-  if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
+  if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL)
     if (NOT onnxruntime_ENABLE_ATEN)
       list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
         "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc"
diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake
index 90203216600fa..3d46c139feea9 100644
--- a/cmake/onnxruntime_providers_tensorrt.cmake
+++ b/cmake/onnxruntime_providers_tensorrt.cmake
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-
+  if(onnxruntime_DISABLE_CONTRIB_OPS)
+    message(FATAL_ERROR "To compile the TensorRT execution provider, contrib ops have to be enabled in order to dump an engine using the com.microsoft:EPContext node.")
+  endif()
   add_definitions(-DUSE_TENSORRT=1)
   if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER)
     add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER)
@@ -154,8 +156,11 @@
   # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121
   # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries.
   # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}.
-  set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
-
+  if(onnxruntime_CUDA_MINIMAL)
+    set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
+  else()
+    set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
+  endif()
   file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
     "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc"
@@ -190,6 +195,9 @@
   if (WIN32)
     target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456)
   endif()
+  if(onnxruntime_CUDA_MINIMAL)
+    target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE USE_CUDA_MINIMAL=1)
+  endif()

   # Needed for the provider interface, as it includes training headers when training is enabled
   if (onnxruntime_ENABLE_TRAINING_OPS)
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
index c4a83efa01a91..fd7b19dea724d 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc
@@ -192,6 +192,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
+  void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                              gsl::span<const int64_t> input_dims,
+                                              InlinedVector<float>& scales) const override {
+    p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
+  }

 #ifndef DISABLE_CONTRIB_OPS
   Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override {
@@ -294,12 +299,6 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
   Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
   Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

-  void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
-                                              gsl::span<const int64_t> input_dims,
-                                              InlinedVector<float>& scales) const override {
-    p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
-  }
-
 #ifdef ENABLE_ATEN
   Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
 #endif
diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
index c0e674827e4d1..840d6f8e3e7aa 100644
--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h
+++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h
@@ -141,7 +141,9 @@ struct ProviderHostCPU {
   virtual Status Scan__Compute(const Scan<9>* p, OpKernelContext* ctx) = 0;
   virtual Status Scan__SetupSubgraphExecutionInfo(Scan<8>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
   virtual Status Scan__SetupSubgraphExecutionInfo(Scan<9>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
-
+  virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                                      gsl::span<const int64_t> input_dims,
+                                                      InlinedVector<float>& scales) const = 0;
 #ifndef DISABLE_CONTRIB_OPS
   virtual Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) = 0;
   virtual Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) = 0;
@@ -203,10 +205,6 @@ struct ProviderHostCPU {
   virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
   virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

-  virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
-                                                      gsl::span<const int64_t> input_dims,
-                                                      InlinedVector<float>& scales) const = 0;
-
 #ifdef ENABLE_ATEN
   virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
 #endif
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
index 3c0bf183362dd..58e57572131b1 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -81,6 +81,9 @@ CudaStream::CudaStream(cudaStream_t stream,
     cudnn_handle_ = external_cudnn_handle;
     CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
   }
+#else
+  (void)(external_cudnn_handle);
+  (void)(external_cublas_handle);
 #endif
 }

diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 9a2d431badbb5..8a601c156bd0a 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -287,6 +287,7 @@ void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const
   return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
 }
+#ifndef USE_CUDA_MINIMAL
 template <>
 Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line) {
   return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
 }
@@ -306,6 +307,7 @@ template <>
 void CudaCall<cudnnStatus_t, true>(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line) {
   return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
 }
+#endif

 #if NV_TENSORRT_MAJOR >= 10
 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
@@ -1119,20 +1121,26 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
 TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) {
   if (has_user_compute_stream) {
     CUDA_CALL_THROW(cudaSetDevice(device_id));
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream)));
+#else
+    (void)(stream);
+#endif
   }
 }

 TensorrtExecutionProvider::PerThreadContext::~PerThreadContext() {
+#ifndef USE_CUDA_MINIMAL
   if (external_cublas_handle_ != nullptr) {
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
   }
   if (external_cudnn_handle_ != nullptr) {
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
   }
+#endif
   trt_context_map_.clear();
 }

@@ -1268,10 +1276,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
   if (info.has_user_compute_stream) {
     external_stream_ = true;
     stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
+#endif
   }

   std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
@@ -1442,6 +1452,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     if (!ep_context_embed_mode_env.empty()) {
       ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
     }
+    // In case the EP context is dumped, the engine cache has to be enabled.
+    if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) {
+      engine_cache_enable_ = true;
+    }

     enable_engine_cache_for_ep_context_model();

@@ -1737,8 +1751,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
   }

   if (external_stream_) {
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
+#endif
   }

   if (!external_stream_ && stream_) {
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index ec140579569b9..b58e86237860c 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -3,9 +3,13 @@
 #pragma once
 #include <ctime>
+#ifndef USE_CUDA_MINIMAL
 #include <cudnn.h>
-#include <cublas_v2.h>
-
+#else
+typedef void* cudnnHandle_t;
+typedef void* cublasHandle_t;
+typedef void* cudnnStatus_t;
+#endif
 #include "core/providers/tensorrt/nv_includes.h"

 #include "core/platform/ort_mutex.h"