sync fbcode cp pg initialize #1177

Merged: 1 commit, Apr 12, 2024
27 changes: 21 additions & 6 deletions fairscale/nn/model_parallel/initialize.py
@@ -20,8 +20,6 @@
# limitations under the License.


"""Model and data parallel groups."""

from typing import List, Optional

import torch
@@ -41,13 +39,13 @@


def initialize_model_parallel(
- model_parallel_size: int,
- context_parallel_size: int = 1,
+ model_parallel_size_: int,
pipeline_length: int = 1,
+ context_parallel_size: int = 1,
*,
model_parallel_backend: Optional[str] = None,
- cp_backend: Optional[str] = None,
pipeline_backend: Optional[str] = None,
+ cp_backend: Optional[str] = None,
ddp_backend: Optional[str] = None,
) -> None:
"""
@@ -67,11 +65,28 @@ def initialize_model_parallel(
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.

Process groups are initialized in the order MP, CP, PP, DP.

Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to
parallelize the model tensor, 2 GPUs to parallelize the context (sequence length),
and 2 GPUs to parallelize the model pipeline. The present function will create
8 tensor model-parallel groups, 8 context-parallel groups, 8 pipeline
model-parallel groups, and 8 data-parallel groups as follows
(when alternate_pp_config = False):
8 data-parallel groups:
    [g0, g4], [g1, g5], [g2, g6], [g3, g7], [g8, g12], [g9, g13], [g10, g14], [g11, g15]
8 tensor model-parallel groups:
    [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
8 context-parallel groups:
    [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 pipeline model-parallel groups:
    [g0, g8], [g1, g9], [g2, g10], [g3, g11], [g4, g12], [g5, g13], [g6, g14], [g7, g15]
"""
# Get world size and rank. Ensure basic consistency.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
- model_parallel_size = int(min(model_parallel_size, world_size))
+ model_parallel_size = int(min(model_parallel_size_, world_size))
ensure_divisibility(world_size, model_parallel_size)
ensure_divisibility(world_size, context_parallel_size)
ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
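For orientation, here is a minimal sketch of how the reordered signature might be called. The backend choices and the torchrun-style launch are assumptions for illustration, not taken from this PR.

```python
import torch.distributed as dist

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

# Assumes a torchrun-style launch so RANK/WORLD_SIZE/MASTER_ADDR are already set.
dist.init_process_group(backend="nccl")

# 16 GPUs split as 2-way tensor parallel x 2-way context parallel x 2-way
# pipeline parallel; the remaining factor of 2 becomes data parallel.
initialize_model_parallel(
    2,                              # model_parallel_size_
    pipeline_length=2,
    context_parallel_size=2,
    model_parallel_backend="nccl",  # backend values here are illustrative
    pipeline_backend="nccl",
    cp_backend="nccl",
    ddp_backend="nccl",
)
```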
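The group layout listed in the docstring can be re-derived from the strides implied by the MP, CP, PP, DP ordering. The snippet below is a standalone illustration of that arithmetic only; it does not call fairscale, and the stride choices are inferred from the 16-GPU example above rather than copied from the patch.

```python
# Re-derive the 16-GPU layout: model parallel = 2, context parallel = 2,
# pipeline = 2, which leaves data parallel = 2.
world_size = 16
mp, cp, pp = 2, 2, 2
dp = world_size // (mp * cp * pp)

# Tensor model-parallel groups: consecutive ranks.
mp_groups = [list(range(i, i + mp)) for i in range(0, world_size, mp)]

# Context-parallel groups: stride mp inside each (mp * cp) block.
cp_groups = [
    [block + off + k * mp for k in range(cp)]
    for block in range(0, world_size, mp * cp)
    for off in range(mp)
]

# Pipeline model-parallel groups: stride world_size // pp across the whole job.
pp_groups = [
    [r + k * (world_size // pp) for k in range(pp)]
    for r in range(world_size // pp)
]

# Data-parallel groups: stride (mp * cp) inside each pipeline-stage block.
stage = world_size // pp
dp_groups = [
    [block + off + k * mp * cp for k in range(dp)]
    for block in range(0, world_size, stage)
    for off in range(mp * cp)
]

print(mp_groups)  # [[0, 1], [2, 3], ..., [14, 15]]
print(cp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7], ...]
print(pp_groups)  # [[0, 8], [1, 9], ..., [7, 15]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], ...]
```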