Skip to content

Commit

Permalink
Automatic Model Parallelism Through FX (#1933)
Browse files Browse the repository at this point in the history
* WIP

* add dist ops

* add index propagation

* support tp for linears

* add embedding & weight tie

* address comments

* lint

* fix

* fix

* debug

* fix

* fix tests

* add experimental API

* nit

* fix api

* fix api

* format

* clean tests

* fix weight_map

* add weights loading

* format

* fix

* fix

* enable tests

* address comments
  • Loading branch information
zhenglongjiepheonix authored Aug 12, 2024
1 parent cfaece8 commit 5eaf91b
Show file tree
Hide file tree
Showing 13 changed files with 2,244 additions and 0 deletions.
65 changes: 65 additions & 0 deletions .github/workflows/test_fx_automatic_parallel.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: Automatic Model Parallelism Test on GPUs

on:
pull_request:
branches:
- main
paths:
- 'optimum/fx/parallelization/**.py'
push:
branches:
- main
paths:
- 'optimum/fx/parallelization/**.py'

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
run_gpu_tests:
strategy:
fail-fast: false
matrix:
config:
- name: GPU-enabled Optimum Test Suite
image: nvidia/cuda:12.4.1-devel-ubuntu22.04
gpu_target: ["nvidia-multi-gpu-l4-runners", "nvidia-multi-gpu-a10-runners"]

name: ${{ matrix.config.name }}
runs-on:
group: "${{matrix.gpu_target}}"

container:
image: ${{ matrix.config.image }}
options: --mount type=tmpfs,destination=/tmp --shm-size 64gb --gpus all --ipc host -v /mnt/hf_cache:/mnt/cache/
env:
NCCL_DEBUG: INFO
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
defaults:
run:
shell: bash

steps:
- uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Checkout optimum
uses: actions/checkout@v4
with:
fetch-depth: 1

- name: Run nvidia-smi
run: |
nvidia-smi
- name: Install dependencies
run: |
python3 -m pip install -U pip
python3 -m pip install torch transformers
python3 -m pip install .[tests]
- name: Run automatic model parallelism tests
run: |
pytest -s -v -o log_cli=true tests/fx/parallelization
16 changes: 16 additions & 0 deletions optimum/fx/parallelization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .api import parallelize_backend, parallelize_model
from .core import Config, ParallelExecutionCtx
126 changes: 126 additions & 0 deletions optimum/fx/parallelization/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from functools import partial
from typing import List, Union

import torch
from torch.fx import GraphModule

from .core import Config, ParallelExecutionCtx
from .passes import build_parallel_pass_pipeline
from .utils import (
MetaAwareMethodsPatcher,
download_model_from_hf,
initialize_parameter_meta,
move_model_to_device,
try_collect_weight_map,
)


def parallelize_backend(
graph_module: GraphModule, example_inputs: List[torch.Tensor], ctx: ParallelExecutionCtx, config: Config
) -> GraphModule:
ctx.example_inputs = example_inputs
pass_pipeline = build_parallel_pass_pipeline()
graph_module = pass_pipeline(graph_module=graph_module, ctx=ctx, config=config)
ctx.compile_times += 1
ctx.last_optimized_graph_module = graph_module
return graph_module


def parallelize_model(
model: Union[torch.nn.Module, str],
parallel_ctx: ParallelExecutionCtx,
*model_args,
**kwargs,
):
"""
API for automatic model parallelism through Pytorch FX.
Args:
model (Union[torch.nn.Module, str]):
Model to parallelize, could either be a module or a model id on the Huggingface Hub.
parallel_ctx (ParallelExecutionCtx):
Parallel execution context containing process groups the current process belongs to.
*model_args (Any):
Additional postional arguments for intializing the model if a model id is passed.
revision (str, defaults to `main`):
Model revision for weights downloading if a model id is passed.
cache_dir (Optional[str], defaults to `None`):
Cache directory to store downloaded weights. Defaults to None.
local_files_only (bool, defaults to `False`):
Whether to use local files only, will avoid downloading from remote if set to `True`.
skip_load_weights (bool, defaults to `False`):
Whether to skip loading weights from disk to model.
**kwargs (Dict[str, Any]):
Addtional keyword arguments for overriding fields in parallel config, model config and `Model.__init__`.
"""
revision = kwargs.pop("revision", "main")
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
skip_load_weights = kwargs.pop("skip_load_weights", False)

parallel_config = Config()
for k, v in dict(kwargs).items():
if k in parallel_config.__dict__:
setattr(parallel_config, k, v)
kwargs.pop(k)

if isinstance(model, str):
from transformers import AutoConfig

is_local = os.path.isdir(model)
if not is_local:
hf_folder = download_model_from_hf(
model_name_or_path=model,
cache_dir=cache_dir,
revision=revision,
local_files_only=local_files_only,
skip_download_weights=skip_load_weights,
)
else:
hf_folder = model

# should be able to load config using only local files
model_config, kwargs = AutoConfig.from_pretrained(
hf_folder, revision=revision, local_files_only=True, return_unused_kwargs=True, **kwargs
)

# try getting model class info from config
model_arch = model_config.architectures
model_cls = getattr(importlib.import_module("transformers"), model_arch[0])

if not skip_load_weights:
parallel_ctx.weight_map = try_collect_weight_map(model, cache_dir, hf_folder)

torch_dtype, dtype_orig = kwargs.pop("torch_dtype", None), None
if torch_dtype is not None:
dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)

with MetaAwareMethodsPatcher():
model = model_cls(model_config, *model_args, **kwargs)
# TODO: remove this once support training-time trace
model.eval()

if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

move_model_to_device(model, device=parallel_ctx.current_device)
initialize_parameter_meta(model)
backend = partial(parallelize_backend, ctx=parallel_ctx, config=parallel_config)
model = torch.compile(model, fullgraph=True, backend=backend)
return model
167 changes: 167 additions & 0 deletions optimum/fx/parallelization/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.fx import GraphModule


class HashableSlice:
def __init__(self, start: Optional[int] = None, stop: Optional[int] = None, step: Optional[int] = None) -> None:
self.start = start
self.stop = stop
self.step = step

def __hash__(self) -> int:
return hash(f"{self.start},{self.stop},{self.step}")

def __eq__(self, value: object) -> bool:
return (
isinstance(value, HashableSlice)
and self.start == value.start
and self.stop == value.stop
and self.step == value.step
)

def to_slice(self) -> slice:
return slice(self.start, self.stop, self.step)


@dataclass
class ParameterSlice:
"""
A slice of parameter which corresponds to a tensor in weight dict. Only support slicing
along a specific axis (the potential parallel axis) right now.
Attributes:
- source (`Optional[str]`, defaults to `None`):
Original parameter name which can be found in the weight dict.
- shape (`Optional[Tuple]`, defaults to `None`):
Shape of parameter tensor corresponding to `source`.
- index (`slice`, defaults to `slice(None, None, None)`):
Index to slice the tensor on the parallel axis. Assume tensor in weight dict has the same
layout as their correspondings in memory.
"""

source: Optional[str] = None
shape: Optional[Tuple] = None
index: slice = slice(None, None, None)


@dataclass
class ParameterMeta:
"""
Parameter meta information.
Attributes:
- is_tied (`bool`, defaults to `False`):
Whether the parameter is shared accross multiple modules.
- is_parallel (`bool`, defaults to `False`):
Whether the parameter needs to be parallelized.
- is_modified_meta (`bool`, defaults to `False`):
Whether the meta has already been modified since initialization.
- need_initialize (`bool`, defaults to `False`):
Whether need to manually initialize weights if not provided in weight map.
- init_fn (`Optional[Callable]`, defaults to `None`):
Initialization function, can override `weight_init_fn` in `Config` if not None.
- dim (`int`, defaults to `0`):
Axis on which `mapping` is based, also the parallel axis if `is_parallel`.
- mapping (`Dict[HashableSlice, ParameterSlice]`):
Mapping between the current parameter and weight tensor stored in weight map.
"""

is_tied: bool = False
is_parallel: bool = False
is_modified_meta: bool = False
need_initialize: bool = False
init_fn: Optional[Callable] = None
dim: int = 0
mapping: Dict[HashableSlice, ParameterSlice] = field(default_factory=dict)


@dataclass
class ParallelExecutionCtx:
"""
Parallel execution context which contains runtime information.
Attributes:
- tp_group (`dist.ProcessGroup`):
Tensor parallel process group the current process belongs to.
- current_device (`torch.device`):
Device correpsonding to the current process.
- example_inputs (`List[Any]`):
A list of tensors which are used as example inputs for graphs captured by dynamo.
- parallel_layer_cache (`Dict[str, nn.Module]`):
Cache which maps layers(`nn.Linear`, `nn.Embedding`) to their parallel counterparts.
Note that we will build the cache in the first compilation process, and for recompilations
later on, we will directly replace the modules with their parallel counterparts in the cache,
because we have to make sure we don't initiate new parameters and replace original ones when
recompilation happens in training process.
- weight_map (`Dict[str, str]`):
Mapping between parameter names and their locations on disk, useful when loading weights
from disk.
- last_optimized_graph_module (`Optional[GraphModule]`, defaults to `None`):
Optimized graph module corresponding to the latest compilation.
- compile_times (`int`, defaults to `0`):
Number of compilation times happened during the whole process.
"""

tp_group: dist.ProcessGroup
current_device: torch.device
example_inputs: List[Any] = field(default_factory=list)
parallel_layer_cache: Dict[str, nn.Module] = field(default_factory=dict)
weight_map: Dict[str, str] = field(default_factory=dict)
last_optimized_graph_module: Optional[GraphModule] = None
compile_times: int = 0


@dataclass
class Config:
"""
Static config which contains instructions which do not change in runtime.
Attributes:
- lint_and_recompile (`bool`, defaults to `True`):
Whether to run graph linting and module recompilation after every pass.
- clean_markers_after_all_passes (`bool`, defaults to `True`):
Whether to clean markers of analytical passes after all passes have run.
- weight_init_fn (`Callable`, defaults to `partial(nn.init.normal_, std=0.02)`)
Initialization function of weights in `nn.Linear` and `nn.Embedding` layers,
if not provided weights loading path.
"""

lint_and_recompile: bool = True
clean_markers_after_all_passes: bool = True
weight_init_fn: Callable = partial(nn.init.normal_, std=0.02)
21 changes: 21 additions & 0 deletions optimum/fx/parallelization/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dist_ops import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_scatter,
scatter,
)
Loading

0 comments on commit 5eaf91b

Please sign in to comment.