Commit 37229ce: fix review comments

Sergei-Lebedev committed Nov 9, 2021
1 parent 84711a8
Showing 5 changed files with 75 additions and 29 deletions.
12 changes: 9 additions & 3 deletions include/torch_ucc.hpp
@@ -120,6 +120,7 @@ class ProcessGroupUCC : public ProcessGroup {
std::exception_ptr eptr_;
std::vector<ucp_ep_h> *eps;
int rank;
int comm_id;
};

class WorkUCC : public ProcessGroup::Work {
@@ -286,7 +287,7 @@ class CommPG {
bool stop_progress_loop;
bool collective_inprogress;

void check_communicator_status(int my_rank, uint64_t seq_num,
void check_communicator_status(int my_rank, int comm_id, uint64_t seq_num,
std::vector<ucp_ep_h> *eps);
public:
c10::DeviceIndex cuda_device_index;
@@ -320,7 +321,11 @@ class CommPG {
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee);
ucc_ee_h ee,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id);

#endif

void enqueue_collective(
@@ -329,7 +334,8 @@
ucc_coll_args_t& coll,
ucc_team_h team,
std::vector<ucp_ep_h> *eps,
int rank);
int rank,
int comm_id);

static std::shared_ptr<CommPG> get_comm(
uint32_t& id,
1 change: 1 addition & 0 deletions include/torch_ucc_comm.hpp
@@ -171,6 +171,7 @@ enum torch_ucc_rank_state_t {
TORCH_UCC_RANK_STATE_COLLECTIVE_NOT_POSTED,
TORCH_UCC_RANK_STATE_COLLECTIVE_INPROGRESS,
TORCH_UCC_RANK_STATE_COLLECTIVE_TIMEOUT,
TORCH_UCC_RANK_STATE_DEVICE_ERROR,
TORCH_UCC_RANK_STATE_COLLECTIVE_DONE
};

1 change: 1 addition & 0 deletions setup.py
@@ -53,6 +53,7 @@
)
else:
print("CUDA support is enabled")
plugin_libraries.append("cuda")
plugin_compile_args.append("-DUSE_CUDA")
module = cpp_extension.CUDAExtension(
name = "torch_ucc",
77 changes: 60 additions & 17 deletions src/torch_ucc.cpp
@@ -66,6 +66,7 @@ struct torch_ucc_config_t {
std::array<bool, 32> blocking_wait;
bool enable_profiling;
bool use_future;
int comm_state_check_timeout;
} torch_ucc_config;

void read_confg() {
@@ -103,6 +104,11 @@ void read_confg() {
if (env) {
torch_ucc_config.enable_profiling = std::atoi(env);
}
torch_ucc_config.comm_state_check_timeout = 2000;
env = std::getenv("TORCH_UCC_COMM_CHECK_TIMEOUT");
if (env) {
torch_ucc_config.comm_state_check_timeout = std::atoi(env);
}
}

void check_device(c10::Device dev1, c10::Device dev2) {
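
The new comm_state_check_timeout setting defaults to 2000 ms and is only replaced when TORCH_UCC_COMM_CHECK_TIMEOUT is present in the environment. A minimal standalone sketch of that override pattern follows; the variable name matches the diff, while the main() wrapper and the printed message are illustrative assumptions.

#include <cstdlib>
#include <iostream>

int main() {
  // Default used when the environment variable is unset (matches the diff: 2000 ms).
  int comm_state_check_timeout = 2000;
  const char* env = std::getenv("TORCH_UCC_COMM_CHECK_TIMEOUT");
  if (env) {
    // std::atoi returns 0 for non-numeric input; the diff applies the same parsing.
    comm_state_check_timeout = std::atoi(env);
  }
  std::cout << "communicator state check timeout: " << comm_state_check_timeout
            << " ms" << std::endl;
  return 0;
}
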
@@ -215,6 +221,27 @@ ucs_status_t torch_ucc_timeout_am_cb(
torch_ucc_timeout_desc_t *dsc = (torch_ucc_timeout_desc_t*)header;
CommPG *comm = (CommPG*)arg;
torch_ucc_rank_state_t state;
ucp_request_param_t params;
ucp_tag_t ucp_tag;

#ifdef USE_CUDA
if (comm->cuda_device_index >= 0) {
CUdevice device;
CUresult st;
unsigned int flags;
int active;
st = cuDeviceGet(&device, comm->cuda_device_index);
if (st != CUDA_SUCCESS) {
state = TORCH_UCC_RANK_STATE_DEVICE_ERROR;
goto send_response;
}
st = cuDevicePrimaryCtxGetState(device, &flags, &active);
if (st != CUDA_SUCCESS || !active) {
state = TORCH_UCC_RANK_STATE_DEVICE_ERROR;
goto send_response;
}
}
#endif

state = TORCH_UCC_RANK_STATE_COLLECTIVE_NOT_POSTED;
if (comm->seq_num > dsc->seq_num) {
@@ -233,9 +260,9 @@
}
}
}
ucp_tag_t ucp_tag;

send_response:
TORCH_UCX_MAKE_OOB_SEND_TAG(ucp_tag, 0, dsc->rank, dsc->comm_id);
ucp_request_param_t params;
params.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_DATATYPE;
params.datatype = ucp_dt_make_contig(sizeof(torch_ucc_rank_state_t));
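
The #ifdef USE_CUDA block added above asks the CUDA driver whether the rank's device and its primary context are still usable, and reports TORCH_UCC_RANK_STATE_DEVICE_ERROR when they are not. A self-contained sketch of that driver-API check is below; device ordinal 0, the cuInit() call, and the printed messages are assumptions for illustration (the real callback uses comm->cuda_device_index and an already-initialized driver).

#include <cuda.h>
#include <cstdio>

int main() {
  // Standalone programs must initialize the driver API first.
  if (cuInit(0) != CUDA_SUCCESS) {
    std::printf("driver init failed\n");
    return 1;
  }
  CUdevice device;
  if (cuDeviceGet(&device, 0) != CUDA_SUCCESS) {
    std::printf("device error\n");
    return 1;
  }
  unsigned int flags;
  int active;
  // An inactive primary context (or a failed query) is treated as a device error.
  if (cuDevicePrimaryCtxGetState(device, &flags, &active) != CUDA_SUCCESS || !active) {
    std::printf("device error: primary context not active\n");
    return 1;
  }
  std::printf("device healthy\n");
  return 0;
}

Linking against the CUDA driver library is what the setup.py change above ("cuda" added to plugin_libraries) appears to provide for these calls.
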
@@ -528,7 +555,8 @@ void CommPG::enqueue_collective(
ucc_coll_args_t& coll,
ucc_team_h team,
std::vector<ucp_ep_h> *eps,
int rank) {
int rank,
int comm_id) {
ucc_coll_req_h request;
TORCH_UCC_CHECK(
ucc_collective_init(&coll, &request, team), "failed to init collective");
@@ -540,6 +568,7 @@
entry->future_ = work->getFuture();
entry->eps = eps;
entry->rank = rank;
entry->comm_id = comm_id;
work->entry_ = entry;
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(entry);
@@ -553,7 +582,10 @@ void CommPG::enqueue_cuda_collective(
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee) {
ucc_ee_h ee,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id) {
ucc_coll_req_h request;
TORCH_UCC_CHECK(
ucc_collective_init(&coll, &request, team),
@@ -570,8 +602,11 @@
TORCH_CHECK(st == UCC_OK && post_ev->ev_type == UCC_EVENT_COLLECTIVE_POST);
ucc_ee_ack_event(ee, post_ev);
auto entry =
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request);
std::make_shared<ProcessGroupUCC::ProgressEntry>(&ucc_comm, request, seq_num++);
entry->data = std::move(data);
entry->eps = eps;
entry->rank = rank;
entry->comm_id = comm_id;
work->entry_ = entry;
std::unique_lock<std::mutex> lock(mutex);
progress_queue.push_back(entry);
@@ -582,9 +617,9 @@

void CommPG::check_communicator_status(
int my_rank,
int comm_id,
uint64_t seq_num,
std::vector<ucp_ep_h> *eps) {
ucs_status_ptr_t st;
ucp_request_param_t params_am, params_recv;
torch_ucc_timeout_desc_t dsc;

Expand All @@ -595,17 +630,21 @@ void CommPG::check_communicator_status(
static_cast<ucc_coll_req_h>(request)->status = UCC_OK;
};
dsc.seq_num = seq_num;
dsc.comm_id = 0;
for (auto i = 0; i < eps->size(); i++) {
dsc.comm_id = comm_id;
for (auto i = 0; i < (int)eps->size(); i++) {
if (i == my_rank) {
comm_state[i] = TORCH_UCC_RANK_STATE_COLLECTIVE_TIMEOUT;
continue;
}
comm_state[i] = TORCH_UCC_RANK_STATE_NOT_RESPONDIG;
dsc.rank = i;
st = ucp_am_send_nbx((*eps)[i], TORCH_UCC_TIMEOUT_AM_ID, &dsc, sizeof(dsc),
nullptr, 0ul, &params_am);
ucp_am_send_nbx((*eps)[i], TORCH_UCC_TIMEOUT_AM_ID, &dsc, sizeof(dsc),
nullptr, 0ul, &params_am);
ucp_tag_t ucp_tag, ucp_tag_mask;
TORCH_UCX_MAKE_OOB_RECV_TAG(ucp_tag, ucp_tag_mask, 0, i, dsc.comm_id);
params_recv.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK |
UCP_OP_ATTR_FIELD_DATATYPE;;
params_recv.datatype = ucp_dt_make_contig(sizeof(ucc_status_t));
UCP_OP_ATTR_FIELD_DATATYPE;
params_recv.datatype = ucp_dt_make_contig(sizeof(torch_ucc_rank_state_t));
params_recv.cb.recv = [](void* request,
ucs_status_t status,
const ucp_tag_recv_info_t* info,
Expand All @@ -617,12 +656,13 @@ void CommPG::check_communicator_status(
}
auto start = std::chrono::system_clock::now();
auto end = std::chrono::system_clock::now();
while((std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() <= 2000))
while (std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() <=
torch_ucc_config.comm_state_check_timeout)
{
end = std::chrono::system_clock::now();
ucx_comm.progress();
}
for (auto i = 0; i < eps->size(); i++) {
for (auto i = 0; i < (int)eps->size(); i++) {
if (comm_state[i] != TORCH_UCC_RANK_STATE_COLLECTIVE_DONE) {
std::string err_log = c10::str(
"on rank ",
@@ -675,7 +715,8 @@ void CommPG::progress_loop() {
eptr = std::current_exception();
}
if (work->request_->status == UCC_ERR_TIMED_OUT) {
check_communicator_status(work->rank, work->seq_num_, work->eps);
check_communicator_status(work->rank, work->comm_id, work->seq_num_,
work->eps);
}
work->finalize(eptr);
work = nullptr;
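
When a collective times out, progress_loop now forwards the communicator id so check_communicator_status can tag its out-of-band messages. That routine waits for peer responses with a wall-clock deadline, now bounded by torch_ucc_config.comm_state_check_timeout instead of a hard-coded 2000 ms. A small standalone sketch of that deadline-driven polling pattern follows; progress_until_deadline, the lambda stand-in for ucx_comm.progress(), and main() are assumptions, not code from the diff.

#include <chrono>
#include <functional>
#include <iostream>

// Drive progress until timeout_ms elapses, mirroring the loop in
// check_communicator_status above (the diff's loop has no early exit;
// it simply progresses the worker until the deadline passes).
void progress_until_deadline(int timeout_ms, const std::function<void()>& progress) {
  auto start = std::chrono::system_clock::now();
  auto end = std::chrono::system_clock::now();
  while (std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() <=
         timeout_ms) {
    end = std::chrono::system_clock::now();
    progress();
  }
}

int main() {
  long calls = 0;
  progress_until_deadline(2000, [&calls] { ++calls; });  // stand-in for ucx_comm.progress()
  std::cout << "progressed " << calls << " times within the deadline\n";
  return 0;
}
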
@@ -756,7 +797,8 @@ c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
work->future_ = c10::make_intrusive<at::ivalue::Future>(
c10::ListType::create(c10::TensorType::get()));
}
comm->enqueue_collective(std::move(data), work, coll, team, &eps, rank_);
comm->enqueue_collective(std::move(data), work, coll, team, &eps, rank_,
comm_id);
return work;
}
#ifdef USE_CUDA
@@ -773,7 +815,8 @@
}
cuda_ev->record(at::cuda::getCurrentCUDAStream(dev.index()));
cuda_ev->block(*stream);
comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee);
comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee,
&eps, rank_, comm_id);
cuda_ev->record(*stream);
work->fence = std::move(cuda_ev);
work->ep = &ep;
13 changes: 4 additions & 9 deletions src/torch_ucc_comm.cpp
@@ -25,6 +25,8 @@ const char *torch_ucc_rank_state_string(torch_ucc_rank_state_t state)
return "Collective timeout was triggered";
case TORCH_UCC_RANK_STATE_COLLECTIVE_DONE:
return "Collective was finished";
case TORCH_UCC_RANK_STATE_DEVICE_ERROR:
return "Device error";
default:
return "Unknown state";
};
@@ -71,15 +73,8 @@ CommUCX::CommUCX(int comm_size, const c10::intrusive_ptr<ProcessGroupUCCLogger>&
}

void CommUCX::set_am_recv_handler(const ucp_am_handler_param_t *params) {
ucs_status_t st;

st = ucp_worker_set_am_recv_handler(worker, params);
if (st != UCS_OK) {
logger->logError(
TORCH_UCC_INIT,
c10::str("UCX failed to set am handler:", ucs_status_string(st)));
throw std::runtime_error(ucs_status_string(st));
}
TORCH_UCX_CHECK(ucp_worker_set_am_recv_handler(worker, params),
"UCX failed to set am handler");
}

void CommUCX::progress() {
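
In torch_ucc_comm.cpp, the hand-rolled status check in set_am_recv_handler is collapsed into the TORCH_UCX_CHECK helper. That macro's definition is not part of this diff; the sketch below is only a plausible illustration of such a check-and-throw wrapper, not the repository's actual implementation, and it omits the logger call that the replaced code performed.

#include <stdexcept>
#include <string>
#include <ucs/type/status.h>

// Illustrative only: evaluate a UCX call and throw with a readable message
// when it does not return UCS_OK.
#define UCX_CHECK_EXAMPLE(cmd, msg)                                      \
  do {                                                                   \
    ucs_status_t st_ = (cmd);                                            \
    if (st_ != UCS_OK) {                                                 \
      throw std::runtime_error(std::string(msg) + ": " +                 \
                               ucs_status_string(st_));                  \
    }                                                                    \
  } while (0)
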
