Skip to content

Commit

Permalink
consensus protocol (facebookresearch#39)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch#39

Differential Revision: https://www.internalfb.com/diff/D31765741?entry_point=27

Pulled By: kingchc

fbshipit-source-id: eb987398fc22f12200273934513cbdc24b5bc93f
  • Loading branch information
Sergei-Lebedev authored and facebook-github-bot committed Nov 19, 2021
1 parent e88fd7b commit 8c5abce
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 65 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ jobs:
/bin/bash ./test/start_test.sh ./test/torch_work_test.py --backend=gloo
echo "UCC pt2pt"
/bin/bash ./test/start_test.sh ./test/torch_pt2pt_test.py --backend=gloo
echo "UCC timeout test"
/bin/bash ./test/start_test.sh ./test/torch_timeout_test.py --backend=gloo
- name: Test PARAM
run: |
git clone ${PARAM_LINK} /tmp/param
Expand Down
73 changes: 28 additions & 45 deletions include/torch_ucc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,47 +35,6 @@ namespace c10d {

#define TORCH_UCC_DEVICE_NOT_SET -2

#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))

#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET))

#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
} while (0)

#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)

#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
} else { \
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
} \
} while (0)

#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
} while (0)

#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
(_ucp_tag_mask) = (uint64_t)-1; \
} while (0)

#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA) \
do { \
Expand Down Expand Up @@ -147,15 +106,21 @@ class ProcessGroupUCC : public ProcessGroup {
public:
ProgressEntry(
CommBase* comm,
ucc_coll_req_h request)
: status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
ucc_coll_req_h request,
uint64_t seq_num)
: status_(UCC_INPROGRESS), comm_(comm), request_(request),
seq_num_(seq_num) {}
void finalize(std::exception_ptr eptr = nullptr);
ucc_status_t status_;
CommBase* comm_;
ucc_coll_req_h request_;
uint64_t seq_num_;
std::unique_ptr<WorkData> data;
c10::intrusive_ptr<c10::ivalue::Future> future_;
std::exception_ptr eptr_;
std::vector<ucp_ep_h> *eps;
int rank;
int comm_id;
};

class WorkUCC : public ProcessGroup::Work {
Expand Down Expand Up @@ -326,8 +291,10 @@ class ProcessGroupUCC : public ProcessGroup {

class CommPG {
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
std::vector<torch_ucc_rank_state_t> comm_state;
CommUCX ucx_comm;
CommUCC ucc_comm;
uint64_t seq_num;
c10::DeviceIndex device_index;
std::mutex mutex;
std::thread progress_thread;
Expand All @@ -337,6 +304,8 @@ class CommPG {
bool stop_progress_loop;
bool collective_inprogress;

void check_communicator_status(int my_rank, int comm_id, uint64_t seq_num,
std::vector<ucp_ep_h> *eps);
public:
c10::DeviceIndex cuda_device_index;
CommPG(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
Expand Down Expand Up @@ -369,14 +338,21 @@ class CommPG {
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee);
ucc_ee_h ee,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id);

#endif

void enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team);
ucc_team_h team,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id);

static std::shared_ptr<CommPG> get_comm(
uint32_t& id,
Expand All @@ -399,6 +375,13 @@ class CommPG {
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask);
friend ucs_status_t torch_ucc_timeout_am_cb(
void *arg,
const void *header,
size_t header_length,
void *data,
size_t length,
const ucp_am_recv_param_t *param);
};

} // namespace c10d
74 changes: 69 additions & 5 deletions include/torch_ucc_comm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,49 @@
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)

#define TORCH_UCC_TIMEOUT_AM_ID 0

#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_TAG_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_comm)) << TORCH_UCX_COMM_BITS_OFFSET))

#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm) \
((((uint64_t)(_tag)) << TORCH_UCX_OOB_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_RANK_BITS_OFFSET) | \
(((uint64_t)(_rank)) << TORCH_UCX_COMM_BITS_OFFSET))

#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
} while (0)

#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
#define TORCH_UCX_SPECIFIC_SOURCE_MASK ((uint64_t)-1)

#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm)); \
if ((_rank) == TORCH_UCX_ANY_SOURCE) { \
(_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK; \
} else { \
(_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK; \
} \
} while (0)

#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
} while (0)

#define TORCH_UCX_MAKE_OOB_RECV_TAG( \
_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
do { \
(_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm)); \
(_ucp_tag_mask) = (uint64_t)-1; \
} while (0)

namespace c10d {

// Macro to throw on a non-successful UCC return value.
Expand Down Expand Up @@ -87,7 +130,7 @@ namespace c10d {
LOG(ERROR) << logger->getLogPrefix(_phase) << "[ERROR] " << _msg;
#define TORCH_UCC_LOG_INFO(_phase, _msg) \
LOG(INFO) << logger->getLogPrefix(_phase) << "[INFO] " << _msg;
#define TORCH_UCC_LOG_DEBUG_phase, _msg) \
#define TORCH_UCC_LOG_DEBUG(_phase, _msg) \
VLOG(1) << logger->getLogPrefix(_phase) << "[DEBUG] " << _msg;

enum torch_ucc_phase_t {
Expand All @@ -96,7 +139,8 @@ enum torch_ucc_phase_t {
TORCH_UCC_READY,
TORCH_UCC_COLL_POST,
TORCH_UCC_COLL_PROGRESS,
TORCH_UCC_FINALIZE
TORCH_UCC_FINALIZE,
TORCH_UCC_COMM_CHECK
};

const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
Expand All @@ -106,6 +150,7 @@ const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
{TORCH_UCC_COLL_POST, "COLL_POST"},
{TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
{TORCH_UCC_FINALIZE, "FINALIZE"},
{TORCH_UCC_COMM_CHECK, "COMM_CHECK"}
};

class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
Expand All @@ -124,6 +169,24 @@ class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
};

enum torch_ucc_rank_state_t {
TORCH_UCC_RANK_STATE_NOT_RESPONDIG,
TORCH_UCC_RANK_STATE_COLLECTIVE_NOT_POSTED,
TORCH_UCC_RANK_STATE_COLLECTIVE_INPROGRESS,
TORCH_UCC_RANK_STATE_COLLECTIVE_TIMEOUT,
TORCH_UCC_RANK_STATE_DEVICE_ERROR,
TORCH_UCC_RANK_STATE_COLLECTIVE_DONE,
TORCH_UCC_RANK_STATE_UNKNOWN
};

const char *torch_ucc_rank_state_string(torch_ucc_rank_state_t state);

struct torch_ucc_timeout_desc_t {
int rank;
int comm_id;
uint64_t seq_num;
};

struct torch_ucc_oob_coll_info_t {
c10::intrusive_ptr<Store> store;
uint32_t comm_id;
Expand Down Expand Up @@ -152,11 +215,12 @@ class CommUCX : public CommBase {
ucp_worker_h worker{nullptr};

public:
void progress() override;
void free_request(ucc_coll_req_h request) override;
CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
void progress() override;
void free_request(ucc_coll_req_h request) override;
void set_am_recv_handler(const ucp_am_handler_param_t *params);
~CommUCX();
};

Expand All @@ -166,10 +230,10 @@ class CommUCC : public CommBase {
ucc_context_h context{nullptr};

public:
void progress() override;
CommUCC(
torch_ucc_oob_coll_info_t* oob_info,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
void progress() override;
void free_request(ucc_coll_req_h request) override;
~CommUCC();
};
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
else:
print("CUDA support is enabled")
plugin_libraries.append("cuda")
plugin_compile_args.append("-DUSE_CUDA")
module = cpp_extension.CUDAExtension(
name = "torch_ucc",
Expand Down
Loading

0 comments on commit 8c5abce

Please sign in to comment.