diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py
index 035ed5161..1c3cb68f5 100644
--- a/fairscale/nn/model_parallel/initialize.py
+++ b/fairscale/nn/model_parallel/initialize.py
@@ -20,8 +20,6 @@
 # limitations under the License.
 
 
-"""Model and data parallel groups."""
-
 from typing import List, Optional
 
 import torch
@@ -41,13 +39,13 @@
 
 
 def initialize_model_parallel(
-    model_parallel_size: int,
-    context_parallel_size: int = 1,
+    model_parallel_size_: int,
     pipeline_length: int = 1,
+    context_parallel_size: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
-    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     ddp_backend: Optional[str] = None,
 ) -> None:
     """
@@ -67,11 +65,28 @@ def initialize_model_parallel(
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
+
+    Process groups are initialized in the order of MP, CP, PP, DP.
+
+    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
+    use 2 GPUs to parallelize the model tensor, 2 GPUs to parallelize the context (sequence length), and 2 GPUs to parallelize
+    the model pipeline. The present function will
+    create 8 tensor model-parallel groups, 8 context-parallel groups, 8 pipeline model-parallel groups
+    and 8 data-parallel groups as:
+    when alternate_pp_config = False,
+        8 data_parallel groups:
+            [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
+        8 tensor model-parallel groups:
+            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+        8 context-parallel groups:
+            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+        8 pipeline model-parallel groups:
+            [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size, world_size))
+    model_parallel_size = int(min(model_parallel_size_, world_size))
     ensure_divisibility(world_size, model_parallel_size)
     ensure_divisibility(world_size, context_parallel_size)
     ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
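
For reference, the grouping described in the new docstring can be reproduced with a short standalone sketch (hypothetical helper names, not part of this patch or of fairscale's API): it lays the 16 ranks out on a (PP, DP, CP, MP) grid with MP varying fastest and enumerates each kind of group for MP = CP = PP = 2, matching the lists above.

# Standalone sketch (hypothetical helpers, not fairscale code): enumerate the
# rank groups described in the docstring above for world_size=16, MP=CP=PP=2.
from itertools import product


def enumerate_groups(world_size: int = 16, mp: int = 2, cp: int = 2, pp: int = 2):
    dp = world_size // (mp * cp * pp)

    # Lay ranks out on a (pp, dp, cp, mp) grid with mp varying fastest, so
    # tensor-MP groups are adjacent ranks, CP groups have stride mp, DP groups
    # have stride mp*cp, and PP groups have stride mp*cp*dp.
    def rank(p: int, d: int, c: int, m: int) -> int:
        return ((p * dp + d) * cp + c) * mp + m

    groups = {"tensor_mp": [], "context": [], "data": [], "pipeline": []}
    for p, d, c in product(range(pp), range(dp), range(cp)):
        groups["tensor_mp"].append([rank(p, d, c, m) for m in range(mp)])
    for p, d, m in product(range(pp), range(dp), range(mp)):
        groups["context"].append([rank(p, d, c, m) for c in range(cp)])
    for p, c, m in product(range(pp), range(cp), range(mp)):
        groups["data"].append([rank(p, d, c, m) for d in range(dp)])
    for d, c, m in product(range(dp), range(cp), range(mp)):
        groups["pipeline"].append([rank(p, d, c, m) for p in range(pp)])
    return groups


if __name__ == "__main__":
    for name, ranks in enumerate_groups().items():
        print(name, ranks)

Running the sketch prints eight two-rank groups of each kind, identical to the docstring example (e.g. data: [0, 4] ... [11, 15]; pipeline: [0, 8] ... [7, 15]).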