Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

consensus protocol #39

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 28 additions & 45 deletions include/torch_ucc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,47 +32,6 @@ namespace c10d {

// Sentinel value (-2) meaning that no device has been selected yet.
// NOTE(review): presumably compared against a device index — confirm at the use sites.
// The replacement list is parenthesized so the negative literal always expands
// as a single operand, whatever expression the macro is placed in.
#define TORCH_UCC_DEVICE_NOT_SET (-2)

// Builders for the 64-bit tags (and tag masks) passed as ucp_tag_t values.
// Every builder widens its arguments to uint64_t and shifts them to the
// TORCH_UCX_*_BITS_OFFSET positions. Arguments are NOT masked, so a value
// wider than its field spills into the neighbouring field.

// P2P tag: _tag in the tag field, _rank in the rank field, _comm in the
// communicator field.
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm)                  \
  ((static_cast<uint64_t>(_tag) << TORCH_UCX_TAG_BITS_OFFSET) |     \
   (static_cast<uint64_t>(_rank) << TORCH_UCX_RANK_BITS_OFFSET) |   \
   (static_cast<uint64_t>(_comm) << TORCH_UCX_COMM_BITS_OFFSET))

// OOB tag: _tag in the OOB field, _rank in the rank field, _comm in the
// communicator field. The communicator field must carry _comm (not _rank) so
// that OOB traffic of different communicators cannot match each other.
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm)                  \
  ((static_cast<uint64_t>(_tag) << TORCH_UCX_OOB_BITS_OFFSET) |     \
   (static_cast<uint64_t>(_rank) << TORCH_UCX_RANK_BITS_OFFSET) |   \
   (static_cast<uint64_t>(_comm) << TORCH_UCX_COMM_BITS_OFFSET))

// Stores the P2P tag for a send into _ucp_tag.
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm)       \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm));  \
  } while (0)

// Wildcard rank value recognised by TORCH_UCX_MAKE_RECV_TAG.
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
// Mask with the rank field cleared: the sender's rank is ignored when matching.
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
// Mask with all bits set: every field of the tag must match.
#define TORCH_UCX_SPECIFIC_SOURCE_MASK (static_cast<uint64_t>(-1))

// Stores the P2P tag for a receive into _ucp_tag and the matching mask into
// _ucp_tag_mask. The rank field is masked out only when _rank is
// TORCH_UCX_ANY_SOURCE.
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
  do {                                                                       \
    (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm));           \
    if ((_rank) == TORCH_UCX_ANY_SOURCE) {                                   \
      (_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK;                           \
    } else {                                                                 \
      (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK;                      \
    }                                                                        \
  } while (0)

// Stores the OOB tag for a send into _ucp_tag.
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm)   \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm));  \
  } while (0)

// Stores the OOB tag for a receive into _ucp_tag; OOB receives always match on
// the full tag, so the mask has all bits set.
#define TORCH_UCX_MAKE_OOB_RECV_TAG(                                \
    _ucp_tag, _ucp_tag_mask, _tag, _rank, _comm)                    \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm));  \
    (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK;               \
  } while (0)

#ifdef USE_CUDA
#define SAVE_TENSORS(_TENSORS, _DATA) \
do { \
Expand Down Expand Up @@ -145,15 +104,21 @@ class ProcessGroupUCC : public ProcessGroup {
public:
ProgressEntry(
CommBase* comm,
ucc_coll_req_h request)
: status_(UCC_INPROGRESS), comm_(comm), request_(request) {}
ucc_coll_req_h request,
uint64_t seq_num)
: status_(UCC_INPROGRESS), comm_(comm), request_(request),
seq_num_(seq_num) {}
void finalize(std::exception_ptr eptr = nullptr);
ucc_status_t status_;
CommBase* comm_;
ucc_coll_req_h request_;
uint64_t seq_num_;
std::unique_ptr<WorkData> data;
c10::intrusive_ptr<c10::ivalue::Future> future_;
std::exception_ptr eptr_;
std::vector<ucp_ep_h> *eps;
int rank;
int comm_id;
};

class WorkUCC : public ProcessGroup::Work {
Expand Down Expand Up @@ -311,9 +276,11 @@ template <typename PreProcess, typename PostProcess>

class CommPG {
c10::intrusive_ptr<ProcessGroupUCCLogger> logger;
std::vector<torch_ucc_rank_state_t> comm_state;
std::shared_ptr<torch_ucc_oob_coll_info_t> oob;
CommUCX ucx_comm;
CommUCC ucc_comm;
uint64_t seq_num;
c10::DeviceIndex device_index;
std::mutex mutex;
std::thread progress_thread;
Expand All @@ -323,6 +290,8 @@ class CommPG {
bool stop_progress_loop;
bool collective_inprogress;

// Declaration only; the definition is not in this header.
// NOTE(review): the parameters (my_rank, comm_id, seq_num) have the same names
// and types as the fields of torch_ucc_timeout_desc_t, and eps is a pointer to
// a vector of UCP endpoint handles — presumably the peers to query. Confirm
// against the implementation.
void check_communicator_status(int my_rank, int comm_id, uint64_t seq_num,
std::vector<ucp_ep_h> *eps);
public:
c10::DeviceIndex cuda_device_index;
CommPG(const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger,
Expand Down Expand Up @@ -355,14 +324,21 @@ class CommPG {
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team,
ucc_ee_h ee);
ucc_ee_h ee,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id);

#endif

void enqueue_collective(
std::unique_ptr<ProcessGroupUCC::WorkData> data,
c10::intrusive_ptr<ProcessGroupUCC::WorkUCC> work,
ucc_coll_args_t& coll,
ucc_team_h team);
ucc_team_h team,
std::vector<ucp_ep_h> *eps,
int rank,
int comm_id);

static std::shared_ptr<CommPG> get_comm(
uint32_t& id,
Expand All @@ -385,6 +361,13 @@ class CommPG {
size_t size,
ucp_tag_t ucp_tag,
ucp_tag_t ucp_tag_mask);
// Free function granted friend access to this class's private members.
// NOTE(review): the parameter list has the shape of a UCX active-message
// receive callback (arg, header, header_length, data, length, param);
// presumably this is the handler installed for TORCH_UCC_TIMEOUT_AM_ID —
// confirm against the implementation.
friend ucs_status_t torch_ucc_timeout_am_cb(
void *arg,
const void *header,
size_t header_length,
void *data,
size_t length,
const ucp_am_recv_param_t *param);
};

} // namespace c10d
68 changes: 66 additions & 2 deletions include/torch_ucc_comm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,49 @@
#define TORCH_UCX_TAG_MASK (TORCH_UCX_MAX_TAG << TORCH_UCX_TAG_BITS_OFFSET)
#define TORCH_UCX_OOB_MASK (TORCH_UCX_MAX_OOB << TORCH_UCX_OOB_BITS_OFFSET)

// Identifier (0) for the timeout active message.
// NOTE(review): presumably the id under which torch_ucc_timeout_am_cb is
// registered with the UCP worker — confirm against the implementation.
#define TORCH_UCC_TIMEOUT_AM_ID 0

// Builders for the 64-bit tags (and tag masks) passed as ucp_tag_t values.
// Every builder widens its arguments to uint64_t and shifts them to the
// TORCH_UCX_*_BITS_OFFSET positions. Arguments are NOT masked, so a value
// wider than its field spills into the neighbouring field.

// P2P tag: _tag in the tag field, _rank in the rank field, _comm in the
// communicator field.
#define TORCH_UCX_MAKE_P2P_TAG(_tag, _rank, _comm)                  \
  ((static_cast<uint64_t>(_tag) << TORCH_UCX_TAG_BITS_OFFSET) |     \
   (static_cast<uint64_t>(_rank) << TORCH_UCX_RANK_BITS_OFFSET) |   \
   (static_cast<uint64_t>(_comm) << TORCH_UCX_COMM_BITS_OFFSET))

// OOB tag: _tag in the OOB field, _rank in the rank field, _comm in the
// communicator field. The communicator field must carry _comm (not _rank) so
// that OOB traffic of different communicators cannot match each other.
#define TORCH_UCX_MAKE_OOB_TAG(_tag, _rank, _comm)                  \
  ((static_cast<uint64_t>(_tag) << TORCH_UCX_OOB_BITS_OFFSET) |     \
   (static_cast<uint64_t>(_rank) << TORCH_UCX_RANK_BITS_OFFSET) |   \
   (static_cast<uint64_t>(_comm) << TORCH_UCX_COMM_BITS_OFFSET))

// Stores the P2P tag for a send into _ucp_tag.
#define TORCH_UCX_MAKE_SEND_TAG(_ucp_tag, _tag, _rank, _comm)       \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm));  \
  } while (0)

// Wildcard rank value recognised by TORCH_UCX_MAKE_RECV_TAG.
#define TORCH_UCX_ANY_SOURCE (TORCH_UCX_MAX_RANK - 1)
// Mask with the rank field cleared: the sender's rank is ignored when matching.
#define TORCH_UCX_ANY_SOURCE_MASK (~TORCH_UCX_RANK_MASK)
// Mask with all bits set: every field of the tag must match.
#define TORCH_UCX_SPECIFIC_SOURCE_MASK (static_cast<uint64_t>(-1))

// Stores the P2P tag for a receive into _ucp_tag and the matching mask into
// _ucp_tag_mask. The rank field is masked out only when _rank is
// TORCH_UCX_ANY_SOURCE.
#define TORCH_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _rank, _comm) \
  do {                                                                       \
    (_ucp_tag) = TORCH_UCX_MAKE_P2P_TAG((_tag), (_rank), (_comm));           \
    if ((_rank) == TORCH_UCX_ANY_SOURCE) {                                   \
      (_ucp_tag_mask) = TORCH_UCX_ANY_SOURCE_MASK;                           \
    } else {                                                                 \
      (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK;                      \
    }                                                                        \
  } while (0)

// Stores the OOB tag for a send into _ucp_tag.
#define TORCH_UCX_MAKE_OOB_SEND_TAG(_ucp_tag, _tag, _rank, _comm)   \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm));  \
  } while (0)

// Stores the OOB tag for a receive into _ucp_tag; OOB receives always match on
// the full tag, so the mask has all bits set.
#define TORCH_UCX_MAKE_OOB_RECV_TAG(                                \
    _ucp_tag, _ucp_tag_mask, _tag, _rank, _comm)                    \
  do {                                                              \
    (_ucp_tag) = TORCH_UCX_MAKE_OOB_TAG((_tag), (_rank), (_comm));  \
    (_ucp_tag_mask) = TORCH_UCX_SPECIFIC_SOURCE_MASK;               \
  } while (0)

namespace c10d {

// Macro to throw on a non-successful UCC return value.
Expand Down Expand Up @@ -99,6 +142,7 @@ enum torch_ucc_phase_t {
TORCH_UCC_COLL_POST,
TORCH_UCC_COLL_PROGRESS,
TORCH_UCC_FINALIZE,
TORCH_UCC_COMM_CHECK
};

const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
Expand All @@ -108,6 +152,7 @@ const std::map<torch_ucc_phase_t, std::string> ucc_phase_map = {
{TORCH_UCC_COLL_POST, "COLL_POST"},
{TORCH_UCC_COLL_PROGRESS, "COLL_PROGRESS"},
{TORCH_UCC_FINALIZE, "FINALIZE"},
{TORCH_UCC_COMM_CHECK, "COMM_CHECK"}
};

class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
Expand All @@ -126,6 +171,24 @@ class TORCH_API ProcessGroupUCCLogger : public torch::CustomClassHolder {
torch_ucc_phase_t local_phase = TORCH_UCC_UNKNOWN;
};

// State of a rank, convertible to text with torch_ucc_rank_state_string().
// Values are pinned explicitly so that adding the spelling alias below (or any
// future alias) cannot shift the numbering of the remaining enumerators.
enum torch_ucc_rank_state_t {
  // Original, misspelled name; kept so existing code continues to compile.
  TORCH_UCC_RANK_STATE_NOT_RESPONDIG = 0,
  // Correctly spelled alias with the same value as the name above.
  TORCH_UCC_RANK_STATE_NOT_RESPONDING = TORCH_UCC_RANK_STATE_NOT_RESPONDIG,
  TORCH_UCC_RANK_STATE_COLLECTIVE_NOT_POSTED = 1,
  TORCH_UCC_RANK_STATE_COLLECTIVE_INPROGRESS = 2,
  TORCH_UCC_RANK_STATE_COLLECTIVE_TIMEOUT = 3,
  TORCH_UCC_RANK_STATE_DEVICE_ERROR = 4,
  TORCH_UCC_RANK_STATE_COLLECTIVE_DONE = 5,
  TORCH_UCC_RANK_STATE_UNKNOWN = 6
};

// Maps a torch_ucc_rank_state_t value to a C string. Declaration only; the
// definition is not in this header.
// NOTE(review): presumably returns a static, human-readable name that the
// caller must not free — confirm against the implementation.
const char *torch_ucc_rank_state_string(torch_ucc_rank_state_t state);

// Plain descriptor holding a rank, a communicator id and a sequence number.
// NOTE(review): presumably the payload exchanged with the timeout
// active-message handler; if so, field order and types are part of the wire
// format and must not be changed — confirm against the implementation.
struct torch_ucc_timeout_desc_t {
// Rank the descriptor refers to (sender vs. target is not visible here).
int rank;
// Communicator id.
int comm_id;
// Sequence number; 64-bit, the same type as the seq_num counters declared in torch_ucc.hpp.
uint64_t seq_num;
};

struct torch_ucc_oob_coll_info_t {
c10::intrusive_ptr<Store> store;
uint32_t comm_id;
Expand Down Expand Up @@ -154,11 +217,12 @@ class CommUCX : public CommBase {
ucp_worker_h worker{nullptr};

public:
void progress() override;
void free_request(ucc_coll_req_h request) override;
CommUCX(
int comm_size,
const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger);
void progress() override;
void free_request(ucc_coll_req_h request) override;
void set_am_recv_handler(const ucp_am_handler_param_t *params);
~CommUCX();
};

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
)
else:
print("CUDA support is enabled")
plugin_libraries.append("cuda")
plugin_compile_args.append("-DUSE_CUDA")
module = cpp_extension.CUDAExtension(
name = "torch_ucc",
Expand Down
Loading