Automatic Model Parallelism Through FX #1933

Merged
Changes from 6 commits (of 32 total):
5e39787  WIP (zhenglongjiepheonix, Jun 3, 2024)
7a5d394  add dist ops (zhenglongjiepheonix, Jun 11, 2024)
7e15d26  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jun 11, 2024)
98e5846  add index propagation (zhenglongjiepheonix, Jun 15, 2024)
2036dbb  support tp for linears (zhenglongjiepheonix, Jul 1, 2024)
34fffe8  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 1, 2024)
0876f5d  add embedding & weight tie (zhenglongjiepheonix, Jul 8, 2024)
87e66fb  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 8, 2024)
ae6d9d2  address comments (zhenglongjiepheonix, Jul 8, 2024)
455c0c7  lint (zhenglongjiepheonix, Jul 8, 2024)
27a9bb8  fix (zhenglongjiepheonix, Jul 12, 2024)
473388b  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 12, 2024)
0512b23  fix (zhenglongjiepheonix, Jul 12, 2024)
8ec6727  debug (zhenglongjiepheonix, Jul 13, 2024)
5095f1e  fix (zhenglongjiepheonix, Jul 13, 2024)
f6ebfc0  fix tests (zhenglongjiepheonix, Jul 15, 2024)
e71e5ea  add experimental API (zhenglongjiepheonix, Jul 16, 2024)
eb2a7a6  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 16, 2024)
779c77d  nit (zhenglongjiepheonix, Jul 16, 2024)
e09df2a  fix api (zhenglongjiepheonix, Jul 17, 2024)
22fe1a3  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 17, 2024)
9fd29d1  fix api (zhenglongjiepheonix, Jul 18, 2024)
01cfc25  format (zhenglongjiepheonix, Jul 18, 2024)
8c16267  clean tests (zhenglongjiepheonix, Jul 18, 2024)
8ef00e0  fix weight_map (zhenglongjiepheonix, Jul 18, 2024)
6ef2081  add weights loading (zhenglongjiepheonix, Jul 22, 2024)
2c561d3  format (zhenglongjiepheonix, Jul 22, 2024)
fc96b6f  fix (zhenglongjiepheonix, Jul 22, 2024)
8d2cabb  fix (zhenglongjiepheonix, Jul 23, 2024)
c9c7571  Merge remote-tracking branch 'upstream/main' into longjie/add_automat… (zhenglongjiepheonix, Jul 23, 2024)
97e6431  enable tests (zhenglongjiepheonix, Jul 23, 2024)
efd5d28  address comments (zhenglongjiepheonix, Jul 24, 2024)
13 changes: 13 additions & 0 deletions optimum/fx/parallelization/__init__.py
@@ -0,0 +1,13 @@
import torch
from torch.fx import GraphModule
from typing import List
from .core import ParallelExecutionCtx, Config
from .passes import build_parallel_pass_pipeline


def parallelize_backend(graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config):
    ctx.example_inputs = example_inputs
    pass_pipeline = build_parallel_pass_pipeline()
    graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
    ctx.compile_times += 1
    return graph_module
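
A minimal sketch of how `parallelize_backend` might be wired as a `torch.compile` backend via `functools.partial`. The process-group setup and `MyModel` below are illustrative assumptions, not part of this PR; the experimental API added later in the commit history is the intended entry point.

```python
# Illustrative sketch only; assumes torchrun launched the processes and that
# `MyModel` is a placeholder for any model whose graph dynamo can capture.
from functools import partial

import torch
import torch.distributed as dist

from optimum.fx.parallelization import Config, ParallelExecutionCtx, parallelize_backend

dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank())  # assumes one process per local GPU

ctx = ParallelExecutionCtx(tp_group=dist.group.WORLD, current_device=device)
config = Config()

model = MyModel().to(device)  # hypothetical model
parallel_model = torch.compile(model, backend=partial(parallelize_backend, ctx=ctx, config=config))
```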
109 changes: 109 additions & 0 deletions optimum/fx/parallelization/core.py
@@ -0,0 +1,109 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
import torch.distributed as dist
from functools import partial

class HashableSlice:
    def __init__(self, start: int, stop: int, step: int) -> None:
        self.start = start
        self.stop = stop
        self.step = step

    def __hash__(self) -> int:
        return hash(f'{self.start},{self.stop},{self.step}')

    def __eq__(self, value: object) -> bool:
        return isinstance(value, HashableSlice) and self.start == value.start and \
            self.stop == value.stop and self.step == value.step

    def to_slice(self) -> slice:
        return slice(self.start, self.stop, self.step)


@dataclass
class ParameterMeta:
    # parameter name
    source: str = None
    # which axis to index
    dim: int = None
    # index to slice the tensor
    index: slice = None


@dataclass
class ParameterMapping:
    id: int = None
    meta: ParameterMeta = None


@dataclass
class ParallelParameterMapping(ParameterMapping):
    # the axis being parallelized
    parallel_dim: int = None
    # for multi-source parameter mapping
    mapping: Dict[HashableSlice, ParameterMeta] = field(default_factory=dict)


@dataclass
class ParallelExecutionCtx:
    """
    Parallel execution context which contains runtime information.

    - example_inputs
        A list of tensors which are used as example inputs for graphs captured by dynamo.

    - parallel_layer_cache
        Cache which maps layers (`nn.Linear`, `nn.Embedding`) to their parallel counterparts.
        Note that we build the cache during the first compilation, and for later recompilations
        we directly replace the modules with their parallel counterparts from the cache, because
        we must not initialize new parameters and replace the original ones when recompilation
        happens during training.

    - parameter_mapping
        Mapping between parameter ids and their corresponding names in the original module. Note
        that it changes as we create new parameters to replace the original ones during the first
        compilation. It is useful because dynamo flattens the graph (which invalidates the
        parameter name hierarchy) while the original parameters are kept.

    - weight_map
        Mapping between parameter names and their locations on disk, useful when loading weights
        from disk.

    - tp_group
        Tensor parallel process group the current process belongs to.

    - compile_times
        Number of compilations that have happened during the whole process.

    - current_device
        Device corresponding to the current process.
    """
    example_inputs: List[Any] = field(default_factory=list)
    parallel_layer_cache: Dict[int, nn.Module] = field(default_factory=dict)
    parameter_mapping: Dict[int, ParameterMapping] = field(default_factory=dict)
    weight_map: Dict[str, str] = field(default_factory=dict)
    tp_group: dist.ProcessGroup = None
    compile_times: int = 0
    current_device: torch.device = None


@dataclass
class Config:
    """
    Static config which contains information that does not change at runtime.

    - lint_and_recompile
        Whether to run graph linting and module recompilation after every pass.

    - clean_markers_after_all_passes
        Whether to clean the markers of analytical passes after all passes have run.

    - weight_init_fn
        Initialization function for weights in `nn.Linear` and `nn.Embedding` layers,
        used when no weight-loading path is provided.
    """
    lint_and_recompile: bool = True
    clean_markers_after_all_passes: bool = True
    weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
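
For illustration, the config might be customized as below; both overrides are arbitrary choices for the example, not defaults suggested by this PR.

```python
# Sketch: customizing the static config.
from functools import partial

import torch.nn as nn

from optimum.fx.parallelization import Config

config = Config(
    lint_and_recompile=False,  # skip graph linting / recompilation between passes
    weight_init_fn=partial(nn.init.xavier_normal_, gain=1.0),  # used only when no weight files are provided
)
```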
7 changes: 7 additions & 0 deletions optimum/fx/parallelization/distributed/__init__.py
@@ -0,0 +1,7 @@
from .dist_ops import (
    differentiable_all_gather,
    differentiable_identity,
    differentiable_all_reduce_sum,
    differentiable_scatter,
    scatter,
)
113 changes: 113 additions & 0 deletions optimum/fx/parallelization/distributed/dist_ops.py
Review comment (Member):
This file seems more related to the parallel layers. Hopefully at some point we could use existing backends instead, like nanotron or megatron.

Review comment (Member):
That would be great! Maybe even the torch-native parallelism layers.
@@ -0,0 +1,113 @@
import torch
import torch.distributed as dist

def all_reduce(group: dist.ProcessGroup, tensor: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return tensor

    dist.all_reduce(tensor, group=group)
    return tensor

def all_gather(group: dist.ProcessGroup, tensor: torch.Tensor, gather_dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return tensor
    rank = dist.get_rank(group=group)

    tensor = tensor.contiguous()
    tensors = [torch.empty_like(tensor) for _ in range(world_size)]
    tensors[rank] = tensor

    dist.all_gather(tensors, tensor, group=group)
    return torch.cat(tensors, dim=gather_dim)

def split(group: dist.ProcessGroup, tensor: torch.Tensor, split_dim: int = -1) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return tensor

    rank = dist.get_rank(group)
    size = tensor.size()
    assert size[split_dim] % world_size == 0
    tensors = torch.split(tensor, size[split_dim] // world_size, dim=split_dim)
    tensor = tensors[rank].contiguous()
Review comment (Contributor):
why contiguous?

Review comment (Contributor, author):
Tensors after split may not be contiguous; I think it's better to make them contiguous.
    return tensor

def scatter(group: dist.ProcessGroup, tensor: torch.Tensor, output_tensor: torch.Tensor, scatter_dim: int = 0) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return tensor

    rank = dist.get_rank(group)
    if rank == 0:
        size = tensor.size()
        assert size[scatter_dim] % world_size == 0
        tensors = torch.split(tensor, size[scatter_dim] // world_size, dim=scatter_dim)
        scatter_list = [tensor.contiguous() for tensor in tensors]
        output_tensor = scatter_list[rank]
    else:
        scatter_list = None
    dist.scatter(tensor=output_tensor, scatter_list=scatter_list, src=0, group=group)
    return output_tensor


class DifferentiableIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group: dist.ProcessGroup):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        return DifferentiableAllReduceSum.apply(grad_output, group), None


class DifferentiableAllReduceSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
        ctx.group = group
        return all_reduce(group=group, tensor=tensor)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output, None


class DifferentiableScatter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
        ctx.group = group
        ctx.dim = dim
        return split(group=group, tensor=tensor, split_dim=dim)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return DifferentiableAllGather.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None


class DifferentiableAllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: dist.ProcessGroup, dim: int = -1) -> torch.Tensor:
        ctx.group = group
        ctx.dim = dim
        return all_gather(group=group, tensor=tensor, gather_dim=dim)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return DifferentiableScatter.apply(grad_output, group=ctx.group, dim=ctx.dim), None, None


def differentiable_all_reduce_sum(tensor: torch.Tensor, group: dist.ProcessGroup):
    return DifferentiableAllReduceSum.apply(tensor, group)

def differentiable_identity(tensor: torch.Tensor, group: dist.ProcessGroup):
    return DifferentiableIdentity.apply(tensor, group)

def differentiable_all_gather(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1):
    return DifferentiableAllGather.apply(tensor, group, dim)

def differentiable_scatter(tensor: torch.Tensor, group: dist.ProcessGroup, dim=-1):
    return DifferentiableScatter.apply(tensor, group, dim)
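
To make the intended pairing of these wrappers concrete, here is a small hedged example of the row-parallel pattern built only from the ops in this file: each rank computes a partial matmul and the partials are summed with `differentiable_all_reduce_sum`, which is an all-reduce in forward and an identity in backward. The shapes, and the assumption that the process group is already initialized, are illustrative.

```python
# Illustration only: assumes the script is launched with torchrun and that
# dist.init_process_group() has already been called.
import torch
import torch.distributed as dist

from optimum.fx.parallelization.distributed import differentiable_all_reduce_sum

group = dist.group.WORLD
world_size = dist.get_world_size(group)

# Row-parallel matmul: each rank holds a slice of the input features and the
# matching rows of the weight; partial outputs are summed across ranks.
x_shard = torch.randn(4, 64 // world_size, requires_grad=True)
w_shard = torch.randn(64 // world_size, 32)

partial_out = x_shard @ w_shard
y = differentiable_all_reduce_sum(partial_out, group)  # all-reduce forward, identity backward

y.sum().backward()
assert x_shard.grad.shape == x_shard.shape  # gradient stays local to this rank's shard
```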
1 change: 1 addition & 0 deletions optimum/fx/parallelization/parallel_layers/__init__.py
@@ -0,0 +1 @@
from .linear import RowParallelLinear, ColumnParallelLinear
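
The linear implementations themselves are not shown in this hunk. As a rough, hypothetical sketch (not the PR's code) of how a row-parallel linear typically composes the distributed ops above:

```python
# Hypothetical sketch only; the real RowParallelLinear/ColumnParallelLinear in
# optimum/fx/parallelization/parallel_layers/linear.py differ in details such
# as weight initialization, weight loading, and output gathering.
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from optimum.fx.parallelization.distributed import differentiable_all_reduce_sum


class NaiveRowParallelLinear(nn.Module):
    """Weight is split along in_features; inputs arrive already split along the last dim."""

    def __init__(self, in_features: int, out_features: int, group: dist.ProcessGroup):
        super().__init__()
        self.group = group
        world_size = dist.get_world_size(group)
        assert in_features % world_size == 0
        # each rank owns a row slice of the full (out_features, in_features) weight
        self.weight = nn.Parameter(torch.empty(out_features, in_features // world_size))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        partial_out = F.linear(x_shard, self.weight)
        # sum partial outputs across the tensor-parallel group (differentiable)
        return differentiable_all_reduce_sum(partial_out, self.group)
```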