Skip to content

Commit

Permalink
add get_cp_ranks to model_parallel initialize (#1176)
Browse files Browse the repository at this point in the history
Co-authored-by: amyyang <[email protected]>
  • Loading branch information
amylittleyang and amyyang authored Apr 6, 2024
1 parent 0af41ae commit 8fb39b2
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions fairscale/nn/model_parallel/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def get_context_parallel_group() -> torch.distributed.ProcessGroup:
return _CONTEXT_PARALLEL_GROUP


def get_context_parallel_ranks() -> List[int]:
"""Return context parallel ranks for the context parallel group."""
assert _CONTEXT_PARALLEL_GROUP_RANKS is not None, "context parallel group is not initialized"
return _CONTEXT_PARALLEL_GROUP_RANKS


def get_context_parallel_world_size() -> int:
"""Return world size for the context parallel group."""
return torch.distributed.get_world_size(group=get_context_parallel_group())
Expand Down

0 comments on commit 8fb39b2

Please sign in to comment.