add context parallel group init to mp init #1174

Merged: 1 commit, Apr 4, 2024
108 changes: 80 additions & 28 deletions fairscale/nn/model_parallel/initialize.py
@@ -34,17 +34,21 @@
 _DATA_PARALLEL_GROUP = None
 # Pipeline parallel group that the current rank belongs to.
 _PIPELINE_PARALLEL_GROUP = None
 _PIPELINE_PARALLEL_RANKS = None
 
+_CONTEXT_PARALLEL_GROUP = None
+_CONTEXT_PARALLEL_GROUP_RANKS = None
 
 
 def initialize_model_parallel(
-    model_parallel_size_: int,
+    model_parallel_size: int,
+    context_parallel_size: int = 1,
     pipeline_length: int = 1,
     *,
     model_parallel_backend: Optional[str] = None,
+    cp_backend: Optional[str] = None,
     pipeline_backend: Optional[str] = None,
-    ddp_backend: Optional[str] = None
+    ddp_backend: Optional[str] = None,
 ) -> None:
     """
     Initialize model data parallel groups.
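As a quick illustration of the new signature (a hypothetical launch, not part of this diff): with 16 ranks, 2-way model parallelism, 2-way context parallelism, and a single pipeline stage, the divisibility checks below pass and the remaining factor, 16 / (2 * 1 * 2) = 4, becomes the data parallel size:

import torch.distributed
from fairscale.nn.model_parallel import initialize as mpu

# Assumes the job was launched (e.g. via torchrun) with 16 ranks.
torch.distributed.init_process_group(backend="nccl")
mpu.initialize_model_parallel(
    model_parallel_size=2,
    context_parallel_size=2,
    pipeline_length=1,
)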
@@ -67,19 +71,21 @@ def initialize_model_parallel(
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
-    model_parallel_size = int(min(model_parallel_size_, world_size))
+    model_parallel_size = int(min(model_parallel_size, world_size))
     ensure_divisibility(world_size, model_parallel_size)
-    ensure_divisibility(world_size, model_parallel_size * pipeline_length)
+    ensure_divisibility(world_size, context_parallel_size)
+    ensure_divisibility(world_size, model_parallel_size * pipeline_length * context_parallel_size)
     rank = torch.distributed.get_rank()
 
-    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))
+    data_parallel_size = int(world_size / (model_parallel_size * pipeline_length * context_parallel_size))
 
     if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
-        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing model parallel with size {}".format(model_parallel_size))
+        print("> initializing context parallel with size {}".format(context_parallel_size))
         print("> initializing pipeline with size {}".format(pipeline_length))
+        print("> initializing ddp with size {}".format(data_parallel_size))
 
-    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)
+    groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, context_parallel_size, model_parallel_size)
 
     found = torch.where(groups == rank)
     assert all(len(x) == 1 for x in found)
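To see what the reshape does, here is a worked example with hypothetical sizes (not taken from the PR): with data_parallel_size=4, pipeline_length=1, context_parallel_size=2, and model_parallel_size=2, the model parallel axis varies fastest, so adjacent ranks share a model parallel group while context parallel peers sit one model-parallel stride apart:

import torch

groups = torch.LongTensor(range(16)).reshape(4, 1, 2, 2)  # (dp, pipe, cp, mp)
print(groups[0, 0, 0, :].tolist())  # model parallel group of rank 0: [0, 1]
print(groups[0, 0, :, 0].tolist())  # context parallel group of rank 0: [0, 2]
print(groups[:, 0, 0, 0].tolist())  # data parallel group of rank 0: [0, 4, 8, 12]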
@@ -88,41 +94,81 @@
     # Build the data parallel groups.
     global _DATA_PARALLEL_GROUP
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
-    for j in range(pipeline_length):
-        for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
-            if j == found[1] and k == found[2]:
-                _DATA_PARALLEL_GROUP = group
+    for i in range(pipeline_length):
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                group = torch.distributed.new_group(groups[:, i, j, k].tolist(), backend=ddp_backend)
+                if i == found[1] and j == found[2] and k == found[3]:
+                    _DATA_PARALLEL_GROUP = group
 
     # Build the model parallel groups.
     global _MODEL_PARALLEL_GROUP
-    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
+    assert _MODEL_PARALLEL_GROUP is None, "Model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
-            if i == found[0] and j == found[1]:
-                _MODEL_PARALLEL_GROUP = group
+            for k in range(context_parallel_size):
+                group = torch.distributed.new_group(groups[i, j, k, :].tolist(), backend=model_parallel_backend)
+                if i == found[0] and j == found[1] and k == found[2]:
+                    _MODEL_PARALLEL_GROUP = group
 
     # Build the pipeline parallel groups.
     global _PIPELINE_PARALLEL_GROUP
-    assert _PIPELINE_PARALLEL_GROUP is None, "model parallel group is already initialized"
     global _PIPELINE_PARALLEL_RANKS
-    assert _PIPELINE_PARALLEL_RANKS is None, "model parallel group is already initialized"
+    assert _PIPELINE_PARALLEL_GROUP is None, "Pipeline parallel group is already initialized"
     for i in range(data_parallel_size):
-        for k in range(model_parallel_size):
-            ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
-            if i == found[0] and k == found[2]:
-                _PIPELINE_PARALLEL_GROUP = group
-                _PIPELINE_PARALLEL_RANKS = ranks
+        for j in range(context_parallel_size):
+            for k in range(model_parallel_size):
+                ranks = groups[i, :, j, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=pipeline_backend)
+                if i == found[0] and j == found[2] and k == found[3]:
+                    _PIPELINE_PARALLEL_GROUP = group
+                    _PIPELINE_PARALLEL_RANKS = ranks
+
+    # Build the context parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GROUP_RANKS
+    assert (
+        _CONTEXT_PARALLEL_GROUP is None
+    ), "Context parallelism is already initialized."
+    for i in range(data_parallel_size):
+        for j in range(pipeline_length):
+            for k in range(model_parallel_size):
+                ranks = groups[i, j, :, k].tolist()
+                group = torch.distributed.new_group(ranks, backend=cp_backend)
+                if i == found[0] and j == found[1] and k == found[3]:
+                    _CONTEXT_PARALLEL_GROUP = group
+                    _CONTEXT_PARALLEL_GROUP_RANKS = ranks
 
 
 def model_parallel_is_initialized() -> bool:
     """Check if model and data parallel groups are initialized."""
-    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None:
+    if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None or _PIPELINE_PARALLEL_GROUP is None or _CONTEXT_PARALLEL_GROUP is None:
         return False
     return True
 
 
+def get_context_parallel_group() -> torch.distributed.ProcessGroup:
+    """Get the context parallel group the caller rank belongs to."""
+    assert (
+        _CONTEXT_PARALLEL_GROUP is not None
+    ), "context parallel group is not initialized"
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_world_size() -> int:
+    """Return world size for the context parallel group."""
+    return torch.distributed.get_world_size(group=get_context_parallel_group())
+
+
+def get_context_parallel_rank() -> int:
+    """Return my rank for the context parallel group."""
+    return torch.distributed.get_rank(group=get_context_parallel_group())
Comment on lines +167 to +169

looks like lines 161-165 are the same function - maybe remove this?



 def get_model_parallel_group() -> torch.distributed.ProcessGroup:
     """Get the model parallel group the caller rank belongs to."""
     assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
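One way the new getters could be consumed downstream (my sketch; the all-gather pattern is an assumption, not something this PR ships): each rank holds one shard of the sequence dimension and restores the full sequence across its context parallel group:

import torch
import torch.distributed as dist
from fairscale.nn.model_parallel.initialize import (
    get_context_parallel_group,
    get_context_parallel_world_size,
)

def gather_sequence(local_shard: torch.Tensor) -> torch.Tensor:
    # All-gather the per-rank shards over the context parallel group,
    # then concatenate along the sharded sequence dimension.
    cp_group = get_context_parallel_group()
    shards = [torch.empty_like(local_shard) for _ in range(get_context_parallel_world_size())]
    dist.all_gather(shards, local_shard, group=cp_group)
    return torch.cat(shards, dim=0)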
@@ -179,10 +225,16 @@ def destroy_model_parallel() -> None:
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None

global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None

global _PIPELINE_PARALLEL_GROUP
_PIPELINE_PARALLEL_GROUP = None

global _PIPELINE_PARALLEL_RANKS
_PIPELINE_PARALLEL_RANKS = None

global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None
jasonjk-park marked this conversation as resolved.
Show resolved Hide resolved
global _CONTEXT_PARALLEL_GROUP_RANKS
_CONTEXT_PARALLEL_GROUP_RANKS = None
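Since destroy_model_parallel only clears the module globals, a test that exercises several layouts would re-initialize between runs, along the lines of (hypothetical usage, not from this PR):

from fairscale.nn.model_parallel import initialize as mpu

if mpu.model_parallel_is_initialized():
    mpu.destroy_model_parallel()
# Second layout: context parallelism only.
mpu.initialize_model_parallel(model_parallel_size=1, context_parallel_size=2)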