From 66d75ce96d57f935f26969f37d0639bf673523b2 Mon Sep 17 00:00:00 2001
From: amyyang
Date: Tue, 2 Apr 2024 17:55:43 -0700
Subject: [PATCH] add context parallel group init to mp init

---
 fairscale/nn/model_parallel/initialize.py | 108 ++++++++++++++++------
 1 file changed, 80 insertions(+), 28 deletions(-)

diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py
index 9765008e5..026eb7892 100644
--- a/fairscale/nn/model_parallel/initialize.py
+++ b/fairscale/nn/model_parallel/initialize.py
@@ -34,17 +34,21 @@
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
-
 _PIPELINE_PARALLEL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None
+

 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
@@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()

-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))

     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))

-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)

     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
@@ -88,41 +94,81 @@
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group
+
     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group
+
+    # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
+    for i in range(data_parallel_size):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+
+
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks


 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True


+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
+
+
 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
@@ -179,10 +225,16 @@
 def destroy_model_parallel() -> None:
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
     _MODEL_PARALLEL_GROUP = None
+
     global _DATA_PARALLEL_GROUP
     _DATA_PARALLEL_GROUP = None
+
     global _PIPELINE_PARALLEL_GROUP
     _PIPELINE_PARALLEL_GROUP = None
-    global _PIPELINE_PARALLEL_RANKS
     _PIPELINE_PARALLEL_RANKS = None
+
+    global _CONTEXT_PARALLEL_GROUP
+    _CONTEXT_PARALLEL_GROUP = None
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    _CONTEXT_PARALLEL_GROUP_RANKS = None
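
Illustrative usage (a minimal sketch, not part of the patch): the snippet below assumes 16 processes launched with torchrun on CUDA/NCCL nodes, and exercises the new context_parallel_size argument plus the context-parallel accessors added above. With the reshape order used in the patch, ranks form a (data_parallel, pipeline, context_parallel, model_parallel) grid, so model-parallel ranks are adjacent, then context-parallel, then pipeline, then data-parallel. Everything here besides the functions defined in initialize.py is standard torch.distributed usage; the sizes and the all_gather at the end are arbitrary examples, not part of the API.

    import os

    import torch
    import torch.distributed as dist
    from fairscale.nn.model_parallel import initialize as mpu

    # torchrun provides the rendezvous env vars and LOCAL_RANK.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # With world_size == 16: data_parallel_size = 16 / (4 * 2 * 1) = 2.
    mpu.initialize_model_parallel(
        model_parallel_size=4,
        context_parallel_size=2,
        pipeline_length=1,
    )

    cp_group = mpu.get_context_parallel_group()
    cp_rank = mpu.get_context_parallel_rank()
    cp_world_size = mpu.get_context_parallel_world_size()  # 2 in this layout

    # Example collective over the context parallel group: gather one value from
    # every rank that shares this rank's (data, pipeline, model) coordinates.
    local = torch.full((1,), float(cp_rank), device="cuda")
    gathered = [torch.empty_like(local) for _ in range(cp_world_size)]
    dist.all_gather(gathered, local, group=cp_group)

    mpu.destroy_model_parallel()
    dist.destroy_process_group()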