Skip to content

Commit

Permalink
Enrich CUDA resources with EP options (#19014)
Browse files Browse the repository at this point in the history
Allow custom ops to access CUDA EP options.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Jan 11, 2024
1 parent 58bf836 commit 24e9daf
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 48 deletions.
59 changes: 33 additions & 26 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,45 @@ struct CudaContext : public CustomOpContext {
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};
// below are cuda ep options
int16_t device_id = 0;
int32_t arena_extend_strategy = 0;
int32_t cudnn_conv_algo_search = 0;
bool cudnn_conv_use_max_workspace = true;
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;

void Init(const OrtKernelContext& kernel_ctx) {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cuda_stream = reinterpret_cast<cudaStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource);
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);

device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);

cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
}

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
template <typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
if (sizeof(T) > sizeof(void*)) {
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
T t = {};
memcpy(&t, &resource, sizeof(T));
return t;
}

void* AllocDeferredCpuMem(size_t size) const {
Expand Down
12 changes: 10 additions & 2 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,19 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 2
#define ORT_CUDA_RESOUCE_VERSION 3

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cuda_stream_t = cuda_resource_offset, // 10000
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t, // 10004
arena_extend_strategy_t,
cudnn_conv_algo_search_t,
cudnn_conv_use_max_workspace_t,
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
};
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4418,7 +4418,7 @@ struct OrtApi {
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);

/**
* Get a EP resoure.
* Get a EP resource.
* E.g. a cuda stream or a cublas handle
*
* \param context - Kernel context
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2465,7 +2465,8 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&
stream_,
use_ep_level_unified_stream_,
GetPerThreadContext().CudnnHandle(),
GetPerThreadContext().CublasHandle());
GetPerThreadContext().CublasHandle(),
info_);
}

OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
Expand Down
45 changes: 35 additions & 10 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ CudaStream::CudaStream(cudaStream_t stream,
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this) {
cublasHandle_t external_cublas_handle,
const CUDAExecutionProviderInfo& ep_info) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this),
ep_info_(ep_info) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
Expand Down Expand Up @@ -185,6 +187,27 @@ void* CudaStream::GetResource(int version, int id) const {
case CudaResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
case CudaResource::device_id_t:
return reinterpret_cast<void*>(ep_info_.device_id);
break;
case CudaResource::arena_extend_strategy_t:
return reinterpret_cast<void*>(ep_info_.arena_extend_strategy);
break;
case CudaResource::cudnn_conv_algo_search_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv_algo_search);
break;
case CudaResource::cudnn_conv_use_max_workspace_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv_use_max_workspace);
break;
case CudaResource::cudnn_conv1d_pad_to_nc1d_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv1d_pad_to_nc1d);
break;
case CudaResource::enable_skip_layer_norm_strict_mode_t:
return reinterpret_cast<void*>(ep_info_.enable_skip_layer_norm_strict_mode);
break;
case CudaResource::prefer_nhwc_t:
return reinterpret_cast<void*>(ep_info_.prefer_nhwc);
break;
default:
break;
}
Expand All @@ -207,26 +230,28 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublas_handle) {
cublasHandle_t external_cublas_handle,
const CUDAExecutionProviderInfo& ep_info) {
// wait cuda notification on cuda ep
stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCudaNotificationOnDevice);
// wait cuda notification on cpu ep
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost);
if (!use_existing_stream)
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream](const OrtDevice& device) {
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, ep_info](const OrtDevice& device) {
CUDA_CALL_THROW(cudaSetDevice(device.Id()));
cudaStream_t stream = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// CUDA_CALL_THROW(cudaStreamCreate(&stream));
return std::make_unique<CudaStream>(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr);
return std::make_unique<CudaStream>(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr, ep_info);
});
else
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator,
release_cpu_buffer_on_cuda_stream,
external_stream,
external_cudnn_handle,
external_cublas_handle](const OrtDevice& device) {
return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle);
external_cublas_handle,
ep_info](const OrtDevice& device) {
return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info);
});
}

Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/stream_handles.h"
#include "core/providers/cuda/cuda_execution_provider_info.h"

namespace onnxruntime {

Expand All @@ -23,7 +24,8 @@ struct CudaStream : Stream {
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle);
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);

~CudaStream();

Expand All @@ -50,6 +52,7 @@ struct CudaStream : Stream {
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
const CUDAExecutionProviderInfo ep_info_;
};

void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
Expand All @@ -59,6 +62,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle);
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -3473,7 +3473,8 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
stream_,
external_stream_ /* use_existing_stream */,
external_cudnn_handle_,
external_cublas_handle_);
external_cublas_handle_,
{});
}

OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,6 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelCont
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource");
}
*resource = stream->GetResource(resource_version, resource_id);
if (!(*resource)) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Requested resource does not exist");
}
return nullptr;
API_IMPL_END
};
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
const Ort::Custom::Tensor<float>& X,
const Ort::Custom::Tensor<float>& Y,
Ort::Custom::Tensor<float>& Z) {
auto input_shape = X.Shape();
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
CUSTOM_ENFORCE(cuda_ctx.arena_extend_strategy == 0, "arena_extend_strategy mismatch");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
auto z_raw = Z.Allocate(input_shape);
auto z_raw = Z.Allocate(X.Shape());
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
}

Expand Down

0 comments on commit 24e9daf

Please sign in to comment.