add context parallel group init to mp init
amyyang committed Apr 3, 2024
1 parent 9a173bf commit 66d75ce
Showing 1 changed file with 80 additions and 28 deletions.
108 changes: 80 additions & 28 deletions fairscale/nn/model_parallel/initialize.py
@@ -34,17 +34,21 @@
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None


 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
Expand All @@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()

-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))

     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))

-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)

     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
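
For intuition, here is a small illustration (not part of the commit) of how the new 4-D rank grid behaves. The job size and the rank picked below are made up: a hypothetical 16-rank run split as (data, pipeline, context, model) = (2, 2, 2, 2).

import torch

# Toy grid: world_size = 16 reshaped to (data, pipeline, context, model) = (2, 2, 2, 2).
groups = torch.LongTensor(range(16)).reshape(2, 2, 2, 2)

# torch.where recovers a rank's coordinates in the grid, e.g. for rank 11:
found = torch.where(groups == 11)
print([int(x) for x in found])      # [1, 0, 1, 1]

# Each kind of group is a 1-D slice through the grid, mirroring the
# group-building loops in the next hunk:
print(groups[:, 0, 1, 1].tolist())  # data parallel peers of rank 11:     [3, 11]
print(groups[1, 0, 1, :].tolist())  # model parallel peers of rank 11:    [10, 11]
print(groups[1, :, 1, 1].tolist())  # pipeline parallel peers of rank 11: [11, 15]
print(groups[1, 0, :, 1].tolist())  # context parallel peers of rank 11:  [9, 11]
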
Expand All @@ -88,41 +94,81 @@ def initialize_model_parallel(
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group


     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group


     # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
+    for i in range(data_parallel_size):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
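
Taken together, the hunk above builds all four group types from a single call. Below is a hedged usage sketch, assuming the post-commit keyword signature and a default process group already created (e.g. by torchrun); the backend and the sizes are illustrative, not from the commit.

import torch
import fairscale.nn.model_parallel.initialize as mpu

# Assumes a launcher such as torchrun started 16 processes and set the
# environment variables that init_process_group needs.
torch.distributed.init_process_group(backend="nccl")

mpu.initialize_model_parallel(
    model_parallel_size=2,
    context_parallel_size=2,
    pipeline_length=2,
)
# With world_size = 16, data_parallel_size falls out as 16 / (2 * 2 * 2) = 2.
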


 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True


+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
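
A plausible consumer of the new getters, sketched here for illustration only (the helper name and the contiguous sharding scheme are invented, not part of this commit):

import torch

from fairscale.nn.model_parallel.initialize import (
    get_context_parallel_rank,
    get_context_parallel_world_size,
)


def shard_sequence_for_cp(tokens: torch.Tensor) -> torch.Tensor:
    """Return this rank's contiguous slice of a [batch, seqlen, ...] tensor."""
    cp_size = get_context_parallel_world_size()
    cp_rank = get_context_parallel_rank()
    seqlen = tokens.size(1)
    assert seqlen % cp_size == 0, "sequence length must be divisible by cp_size"
    chunk = seqlen // cp_size
    return tokens[:, cp_rank * chunk : (cp_rank + 1) * chunk]
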


 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
@@ -179,10 +225,16 @@ def destroy_model_parallel() -> None:
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None

global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None

global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None

global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
global _CONTEXT_PARALLEL_GROUP_RANKS
_CONTEXT_PARALLEL_GROUP_RANKS = None
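
A typical reason to call destroy_model_parallel is to rebuild the groups with a different layout, for example between unit tests. A minimal sketch, with invented sizes and assuming the mpu import alias from the earlier sketch:

import fairscale.nn.model_parallel.initialize as mpu

# Clears the module-level cached handles; it does not tear down the
# underlying torch process groups or the default process group.
mpu.destroy_model_parallel()
assert not mpu.model_parallel_is_initialized()

# Hypothetical second layout on the same world.
mpu.initialize_model_parallel(model_parallel_size=4, context_parallel_size=1)
assert mpu.model_parallel_is_initialized()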
