sync fbcode cp pg initialize
amyyang committed Apr 12, 2024
1 parent 8fb39b2 commit cd01f7d
Showing 1 changed file with 21 additions and 6 deletions.
fairscale/nn/model_parallel/initialize.py: 27 changes (21 additions & 6 deletions)
@@ -20,8 +20,6 @@
 # limitations under the License.
 
 
-"""Model and data parallel groups."""
-
 from typing import List, Optional
 
 import torch
@@ -41,13 +39,13 @@
 
 
 def initialize_model_parallel(
-    model_parallel_size: int,
-    context_parallel_size: int = 1,
+    model_parallel_size_: int,
     pipeline_length: int = 1,
+    context_parallel_size: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
-    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     ddp_backend: Optional[str] = None,
 ) -> None:
     """
@@ -67,11 +65,28 @@ def initialize_model_parallel(
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
+    Process groups are initialized in the order MP, CP, PP, DP.
+    Let's say we have a total of 16 GPUs denoted by g0 ... g15, and we use
+    2 GPUs to parallelize the model tensor, 2 GPUs to parallelize the
+    context (sequence length), and 2 GPUs to parallelize the model
+    pipeline. The present function will create 8 tensor model-parallel
+    groups, 8 context-parallel groups, 8 pipeline model-parallel groups,
+    and 8 data-parallel groups as follows (when alternate_pp_config = False):
+    8 data-parallel groups:
+        [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
+    8 tensor model-parallel groups:
+        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
+    8 context-parallel groups:
+        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
+    8 pipeline model-parallel groups:
+        [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
     """
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size, world_size))
+    model_parallel_size = int(min(model_parallel_size_, world_size))
     ensure_divisibility(world_size, model_parallel_size)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
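As a sanity check on the group layout listed in the new docstring, here is a standalone sketch (ours, not part of the commit) that enumerates the same rank groups for 16 GPUs with MP = CP = PP = 2. It assumes the rank layout implied by the listing: MP varies fastest, then CP, then DP, then PP.

# Standalone sketch (not from the commit): reproduce the docstring's group
# layout, assuming rank = mp + cp*MP + dp*MP*CP + pp*MP*CP*DP.
from typing import List

MP, CP, PP = 2, 2, 2          # tensor-, context-, pipeline-parallel sizes
WORLD = 16
DP = WORLD // (MP * CP * PP)  # remaining ranks go to data parallelism: 2

def groups(stride: int, size: int) -> List[List[int]]:
    """Enumerate groups of `size` ranks spaced `stride` apart."""
    out = []
    for block in range(0, WORLD, stride * size):
        for offset in range(stride):
            out.append([block + offset + i * stride for i in range(size)])
    return out

print("MP:", groups(stride=1, size=MP))             # [0, 1], [2, 3], ...
print("CP:", groups(stride=MP, size=CP))            # [0, 2], [1, 3], ...
print("DP:", groups(stride=MP * CP, size=DP))       # [0, 4], [1, 5], ...
print("PP:", groups(stride=MP * CP * DP, size=PP))  # [0, 8], [1, 9], ...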
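The three `ensure_divisibility` calls encode the constraint that the world size must factor as MP x CP x PP x DP. Below is a minimal sketch of what they enforce; the helper body and its assertion message are our approximation of fairscale's `ensure_divisibility`, which upstream keeps in the package's utils module.

# Our approximation of ensure_divisibility plus the checks added here.
def ensure_divisibility(numerator: int, denominator: int) -> None:
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"

world_size = 16
model_parallel_size, pipeline_length, context_parallel_size = 2, 2, 2
ensure_divisibility(world_size, model_parallel_size)
ensure_divisibility(world_size, context_parallel_size)
ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)

# The leftover factor becomes the data-parallel size: 16 // (2 * 2 * 2) == 2.
data_parallel_size = world_size // (model_parallel_size * pipeline_length * context_parallel_size)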

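For context, a hypothetical call site for the updated signature might look like this; the launcher, sizes, and backend choices are illustrative assumptions, not from the commit.

# Hypothetical usage on a 16-GPU job launched with torchrun (illustrative).
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

dist.init_process_group(backend="nccl")
initialize_model_parallel(
    2,                        # model_parallel_size_ (positional)
    pipeline_length=2,
    context_parallel_size=2,  # parameter added/reordered by this commit
    cp_backend="nccl",        # keyword-only backend for the CP groups
)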
