forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cuda_execution_provider.h
165 lines (129 loc) · 5.15 KB
/
cuda_execution_provider.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "cuda_pch.h"
#include "core/platform/ort_mutex.h"
#include "core/graph/constants.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "shared_inc/cuda_utils.h"
#include <deque>
namespace onnxruntime {
// Device id presumably used when the CUDA provider registers its CPU-side
// allocators (inferred from the name; confirm at the use sites).
constexpr int CPU_ALLOCATOR_DEVICE_ID = 0;
// Parameters consumed by the CUDAExecutionProvider constructor.
struct CUDAExecutionProviderInfo {
  // Device id the provider is created for; defaults to device 0.
  OrtDevice::DeviceId device_id = 0;
  // Upper bound on CUDA memory use (presumably bytes — TODO confirm). The
  // default is the largest size_t value, i.e. effectively unlimited.
  size_t cuda_mem_limit = std::numeric_limits<size_t>::max();
};
// Logical device representation.
class CUDAExecutionProvider : public IExecutionProvider {
public:
explicit CUDAExecutionProvider(const CUDAExecutionProviderInfo& info);
virtual ~CUDAExecutionProvider();
AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override;
Status Sync() const override;
Status OnRunStart() override;
Status OnRunEnd() override;
const void* GetExecutionHandle() const noexcept override {
// The CUDA interface does not return anything interesting.
return nullptr;
}
cublasHandle_t PerThreadCublasHandle() {
return GetPerThreadContext().CublasHandle();
}
cudnnHandle_t PerThreadCudnnHandle() {
return GetPerThreadContext().CudnnHandle();
}
curandGenerator_t PerThreadCurandGenerator() {
return GetPerThreadContext().CurandGenerator();
}
template <typename T>
const T* GetConstOnes(size_t count) {
return GetPerThreadContext().template GetConstOnes<T>(count);
}
void AddDeferredReleaseCPUPtr(void* p);
template <typename T>
inline IAllocatorUniquePtr<T> GetScratchBuffer(size_t count_or_bytes) const {
if (count_or_bytes == 0)
return nullptr;
return IAllocator::MakeUniquePtr<T>(GetAllocator(device_id_, OrtMemTypeDefault), count_or_bytes);
}
virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const override;
virtual std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const KernelRegistry*>& kernel_registries) const override;
int GetDeviceId() const { return device_id_; }
private:
OrtDevice::DeviceId device_id_;
size_t cuda_mem_limit_;
struct DeferredReleaseCPUPtrs {
bool recorded = false;
std::vector<void*> cpu_ptrs;
};
std::unordered_map<cudaEvent_t, DeferredReleaseCPUPtrs> deferred_release_cpu_ptr_;
OrtMutex deferred_release_cpu_ptr_mutex_;
class PerThreadContext final {
public:
PerThreadContext(OrtDevice::DeviceId device_id, size_t cuda_mem_limit);
~PerThreadContext();
cublasHandle_t CublasHandle() const {
return cublas_handle_;
}
cudnnHandle_t CudnnHandle() const {
return cudnn_handle_;
}
curandGenerator_t CurandGenerator() const {
return curand_generator_;
}
cudaEvent_t& GetCurrentDeferredReleaseEvent() {
return current_deferred_release_event_;
}
template <typename T>
const T* GetConstOnes(size_t count) {
if (std::is_same<T, float>::value) {
if (!constant_ones_float_) {
constant_ones_float_ = cuda::CreateConstantOnes<float>();
}
return reinterpret_cast<const T*>(constant_ones_float_->GetBuffer(count));
} else if (std::is_same<T, double>::value) {
if (!constant_ones_double_) {
constant_ones_double_ = cuda::CreateConstantOnes<double>();
}
return reinterpret_cast<const T*>(constant_ones_double_->GetBuffer(count));
} else if (std::is_same<T, half>::value) {
if (!constant_ones_half_) {
constant_ones_half_ = cuda::CreateConstantOnes<half>();
}
return reinterpret_cast<const T*>(constant_ones_half_->GetBuffer(count));
} else {
return nullptr;
}
}
AllocatorPtr GetAllocator() const {
return allocator_;
}
private:
cublasHandle_t cublas_handle_ = nullptr;
cudnnHandle_t cudnn_handle_ = nullptr;
curandGenerator_t curand_generator_ = nullptr;
// deferred release for temporary CPU pinned memory used in cudaMemcpyAsync
// note that cudaEvent will be assigned at OnRunEnd() when PerThreadContext destory
// so the ownership is passed to deferred_release_cpu_ptr_
cudaEvent_t current_deferred_release_event_ = nullptr;
std::unique_ptr<cuda::IConstantBuffer<float>> constant_ones_float_;
std::unique_ptr<cuda::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<cuda::IConstantBuffer<half>> constant_ones_half_;
AllocatorPtr allocator_;
};
// thread local context during execution
using PerThreadContextMap = std::unordered_map<const CUDAExecutionProvider*, std::shared_ptr<PerThreadContext>>;
static thread_local std::unique_ptr<PerThreadContextMap> per_thread_context_map_;
// reuse thread local context
mutable std::deque<std::shared_ptr<PerThreadContext>> retired_context_pool_;
mutable OrtMutex context_pool_mutex_;
PerThreadContext& GetPerThreadContext() const;
void ReleasePerThreadStuffs() const;
};
} // namespace onnxruntime