Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use non blocking CUDA streams #1177

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/ctranslate2/devices.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace ctranslate2 {
int get_device_index(Device device);
void set_device_index(Device device, int index);

void synchronize_device(Device device, int index);
void synchronize_stream(Device device);
void synchronize_device(Device device, int index = -1);
void synchronize_stream(Device device, int index = -1);

class ScopedDeviceSetter {
public:
Expand Down
58 changes: 38 additions & 20 deletions src/cuda/utils.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "./utils.h"

#include <atomic>
#include <array>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <vector>

Expand All @@ -13,6 +14,8 @@
namespace ctranslate2 {
namespace cuda {

constexpr int max_gpus = 16;

const char* cublasGetStatusName(cublasStatus_t status)
{
switch (status)
Expand Down Expand Up @@ -42,27 +45,16 @@ namespace ctranslate2 {
}
}

// We assign the default CUDA stream to the main thread since it can interact with
// multiple devices (e.g. load replicas on each GPU). The main thread is created
// before the others, so it will be the first to see the flag below set to true.
static std::atomic<bool> is_main_thread(true);

class CudaStream {
public:
CudaStream() {
if (is_main_thread) {
is_main_thread = false;
_stream = cudaStreamDefault;
} else {
CUDA_CHECK(cudaGetDevice(&_device));
CUDA_CHECK(cudaStreamCreate(&_stream));
}
CudaStream(int device)
: _device(device)
{
CUDA_CHECK(cudaStreamCreateWithFlags(&_stream, cudaStreamNonBlocking));
}
~CudaStream() {
if (_stream != cudaStreamDefault) {
ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
cudaStreamDestroy(_stream);
}
ScopedDeviceSetter scoped_device_setter(Device::CUDA, _device);
cudaStreamDestroy(_stream);
}
cudaStream_t get() const {
return _stream;
Expand All @@ -72,6 +64,32 @@ namespace ctranslate2 {
cudaStream_t _stream;
};

// Holds at most one CudaStream per device index. A stream is only created the
// first time it is requested for a given device.
class CudaStreamPool {
public:
  // Throws std::runtime_error when get_gpu_count() reports more devices than
  // the fixed-size storage below (max_gpus entries) can index.
  CudaStreamPool() {
    const bool too_many_devices = get_gpu_count() > max_gpus;
    if (too_many_devices) {
      throw std::runtime_error("Number of CUDA devices on the machine is larger than "
                               "the maximum supported number ("
                               + std::to_string(max_gpus) + ")");
    }
  }

  // Returns the stream associated with the device reported by cudaGetDevice(),
  // constructing it on the first request for that device.
  cudaStream_t get_device_stream() {
    int current_device = 0;
    CUDA_CHECK(cudaGetDevice(&current_device));

    // std::call_once guarantees the slot is filled exactly once per device.
    auto& slot = _streams[current_device];
    std::call_once(_init_streams[current_device], [&slot, current_device]() {
      slot = std::make_unique<CudaStream>(current_device);
    });

    return slot->get();
  }

private:
  // Stream storage indexed by device; entries stay null until first use.
  std::array<std::unique_ptr<CudaStream>, max_gpus> _streams;
  // One initialization flag per entry of _streams.
  std::array<std::once_flag, max_gpus> _init_streams;
};

class CublasHandle {
public:
CublasHandle() {
Expand All @@ -95,8 +113,8 @@ namespace ctranslate2 {
// when the thread exits.

// Returns a CUDA stream taken from a function-local, thread_local object, so each
// calling thread works with its own instance.
//
// NOTE(review): two declare/return pairs are listed below. As written, the first
// `return` makes the CudaStreamPool pair unreachable. This looks like the before
// and after of a change shown together -- confirm which pair is intended.
cudaStream_t get_cuda_stream() {
static thread_local CudaStream cuda_stream;
return cuda_stream.get();
static thread_local CudaStreamPool cuda_stream_pool;
return cuda_stream_pool.get_device_stream();
}

cublasHandle_t get_cublas_handle() {
Expand Down
17 changes: 13 additions & 4 deletions src/devices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,28 @@ namespace ctranslate2 {
// Waits for the work submitted to a device to complete.
//
// device: device type. Only Device::CUDA triggers a wait; any other value makes
//         this function a no-op.
// index:  device to wait on. A negative value skips the device switch and calls
//         cudaDeviceSynchronize() on whichever device is already current.
//
// When the library is built without CT2_WITH_CUDA, both arguments are ignored.
void synchronize_device(Device device, int index) {
#ifdef CT2_WITH_CUDA
  if (device == Device::CUDA) {
    if (index >= 0) {
      // Select `index` only when the caller named a device; a negative index must
      // never reach ScopedDeviceSetter.
      // NOTE(review): presumably the setter restores the previous device when it
      // goes out of scope -- its definition is not visible here.
      const ScopedDeviceSetter scoped_device_setter(device, index);
      cudaDeviceSynchronize();
    } else {
      cudaDeviceSynchronize();
    }
  }
#else
  (void)device;
  (void)index;
#endif
}

void synchronize_stream(Device device) {
void synchronize_stream(Device device, int index) {
#ifdef CT2_WITH_CUDA
if (device == Device::CUDA) {
cudaStreamSynchronize(cuda::get_cuda_stream());
if (index >= 0) {
const ScopedDeviceSetter scoped_device_setter(device, index);
cudaStreamSynchronize(cuda::get_cuda_stream());
} else {
cudaStreamSynchronize(cuda::get_cuda_stream());
}
}
#else
(void)device;
Expand Down
9 changes: 6 additions & 3 deletions src/models/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,15 @@ namespace ctranslate2 {
if (src_device != Device::CPU) {
ScopedDeviceSetter scoped_device_setter(src_device, src_device_index);
move_variables_to_device(variables, Device::CPU);
synchronize_stream(src_device);
}

// Move variables to the destination device.
if (dst_device != Device::CPU) {
ScopedDeviceSetter scoped_device_setter(dst_device, dst_device_index);
move_variables_to_device(variables, dst_device);
synchronize_stream(dst_device);
}

synchronize_device(src_device, src_device_index); // Wait for asynchronous deallocations.
}

static StorageView copy_variable(const StorageView& variable,
Expand All @@ -108,6 +108,7 @@ namespace ctranslate2 {
if (variable.device() != Device::CPU) {
ScopedDeviceSetter scoped_device_setter(variable.device(), variable.device_index());
copy = variable.to(Device::CPU);
synchronize_stream(variable.device());
}

if (device != Device::CPU) {
Expand All @@ -116,6 +117,7 @@ namespace ctranslate2 {
copy = copy.to(device);
else
copy = variable.to(device);
synchronize_stream(device);
}

return copy;
Expand All @@ -133,7 +135,7 @@ namespace ctranslate2 {
// Destructor. When _variable_index holds entries, it is cleared and the
// destructor then waits on the model's device before returning.
Model::~Model() {
if (!_variable_index.empty()) {
_variable_index.clear();
// NOTE(review): a device-wide wait and a stream-level wait are listed back to
// back with the same trailing comment. This looks like both sides of a change
// shown together -- confirm which single wait is intended.
synchronize_device(_device, _device_index); // Wait for asynchronous deallocations.
synchronize_stream(_device, _device_index); // Wait for asynchronous deallocations.
}
}

Expand Down Expand Up @@ -509,6 +511,7 @@ namespace ctranslate2 {
const ScopedDeviceSetter scoped_device_setter(device, device_index);
model->process_linear_weights();
model->initialize(model_reader);
synchronize_stream(device);
return model;
}

Expand Down