
[WIP] llama-70b
michaelbenayoun committed Feb 5, 2024
1 parent eaae663 commit 6eeeaa0
Showing 9 changed files with 264 additions and 69 deletions.
10 changes: 8 additions & 2 deletions optimum/neuron/accelerate/accelerator.py
@@ -173,7 +173,11 @@ def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, ze
self.gradient_accumulation_steps = num_steps

def _prepare_data_loader_for_distributed(
self, data_loader: DataLoader, num_replicas: int, rank: int, force_drop_last: bool,
self,
data_loader: DataLoader,
num_replicas: int,
rank: int,
force_drop_last: bool,
) -> DataLoader:
# TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`.
if isinstance(data_loader.sampler, DistributedSampler):
@@ -224,7 +228,9 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona
num_replicas = xm.xrt_world_size()
rank = xm.get_local_ordinal()
if self.state.num_processes > 1:
data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last)
data_loader = self._prepare_data_loader_for_distributed(
data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
)
# No need to wrap the dataloader if we are using pipeline parallelism.
if self.state.mp_plugin.pipeline_parallel_size == 1:
data_loader = MpDeviceLoader(data_loader, self.device)
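
For context, a minimal standalone sketch of the data-loader flow this hunk reformats: a DistributedSampler sized from the XLA world, then an MpDeviceLoader wrapper, mirroring _prepare_data_loader_for_distributed and prepare_data_loader above. The dataset is a placeholder and the accelerator's internals differ in detail.

import torch
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
from torch_xla.distributed.parallel_loader import MpDeviceLoader

dataset = TensorDataset(torch.arange(128, dtype=torch.float32))  # placeholder dataset
sampler = DistributedSampler(
    dataset,
    num_replicas=xm.xrt_world_size(),  # same world-size source as the hunk above
    rank=xm.get_local_ordinal(),
    drop_last=True,                    # what force_drop_last toggles
)
data_loader = DataLoader(dataset, batch_size=8, sampler=sampler)
# Only wrapped in MpDeviceLoader when pipeline_parallel_size == 1 (see above).
data_loader = MpDeviceLoader(data_loader, xm.xla_device())
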
2 changes: 2 additions & 0 deletions optimum/neuron/accelerate/utils/dataclasses.py
@@ -147,6 +147,7 @@ class ModelParallelismPlugin:
pipeline_parallel_size: int = 1
pipeline_parallel_num_microbatches: int = 1
pipeline_parallel_use_zero1_optimizer: bool = False
gradient_checkpointing: bool = False
checkpoint_dir: Optional[Union[str, Path]] = None

def __post_init__(self):
@@ -176,6 +177,7 @@ def parallelize_model(
sequence_parallel_enabled=self.sequence_parallel_enabled,
pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches,
pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer,
gradient_checkpointing=self.gradient_checkpointing,
checkpoint_dir=self.checkpoint_dir,
)
return parallelized_model
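
A hedged sketch of how the new flag might be set on the plugin; only the fields visible in the first hunk above are confirmed by this commit, and tensor_parallel_size is an assumed field used for illustration.

from optimum.neuron.accelerate.utils.dataclasses import ModelParallelismPlugin

mp_plugin = ModelParallelismPlugin(
    tensor_parallel_size=8,                    # assumed field, not shown in this hunk
    pipeline_parallel_size=4,
    pipeline_parallel_num_microbatches=4,
    pipeline_parallel_use_zero1_optimizer=False,
    gradient_checkpointing=True,               # new flag added by this commit
)
# parallelize_model(...) then forwards gradient_checkpointing to
# Parallelizer.parallelize, as the second hunk above shows.
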
18 changes: 12 additions & 6 deletions optimum/neuron/distributed/base.py
@@ -21,17 +21,17 @@
from dataclasses import asdict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union, Callable
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, Union

import torch
from transformers import PreTrainedModel
from transformers.utils import WEIGHTS_NAME

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from ..utils.misc import is_main_worker
from ..utils.patching import Patcher
from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
from ..utils.misc import is_main_worker
from .parallel_layers import (
IOSequenceParallelizer,
LayerNormSequenceParallelizer,
@@ -42,6 +42,7 @@
TENSOR_PARALLEL_SHARDS_DIR_NAME,
ParameterMetadata,
WeightInformation,
apply_checkpoint,
initialize_parallel_linear,
initialize_torch_nn_module,
linear_to_parallel_linear,
@@ -243,7 +244,7 @@ def _parallelize(
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
sequence_parallel_enabled: bool = False,
should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None
should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None,
) -> "PreTrainedModel":
"""
Parallelizes the model by transforming regular layer into their parallel counterparts.
@@ -275,6 +276,7 @@ def parallelize(
pipeline_parallel_input_names: Optional[Union[Tuple[str, ...], List[str]]] = None,
pipeline_parallel_num_microbatches: int = 1,
pipeline_parallel_use_zero1_optimizer: bool = False,
gradient_checkpointing: bool = False,
checkpoint_dir: Optional[Union[str, Path]] = None,
) -> "PreTrainedModel":
"""
@@ -299,6 +301,8 @@
pipeline_parallel_use_zero1_optimizer (`bool`, defaults to `False`):
When zero-1 optimizer is used, set this to True, so the PP model will understand that zero-1 optimizer
will handle data parallel gradient averaging.
gradient_checkpointing (`bool`, defaults to `False`):
TODO
checkpoint_dir (`Optional[Union[str, Path]]`):
Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized
model.
@@ -330,14 +334,15 @@ def parallelize(
model, remove_duplicate=True
)


name_to_parameter = dict(named_parameters(model, remove_duplicate=False))
parameter_to_name = {p: n for n, p in name_to_parameter.items()}

xm.master_print(name_to_parameter.keys())

def predicate_func(layer):
for n, p in layer.named_parameters():
if p not in parameter_to_name:
print(n)
xm.master_print(n)
return False
names = {parameter_to_name[p] for p in layer.parameters()}
return names < names_of_the_parameters_to_consider
@@ -527,13 +532,14 @@ def predicate_func(layer):
leaf_module_cls=cls.PIPELINE_PARALLELISM_SPECS_CLS.leaf_module_cls(),
use_zero1_optimizer=pipeline_parallel_use_zero1_optimizer,
)
if gradient_checkpointing:
apply_checkpoint(model)

xm.rendezvous("End of pipeline paralellism")

if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)


return model

@classmethod
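
A hedged call-shape sketch for the new keyword on Parallelizer.parallelize, matching the docstring added above. parallelizer_cls and model are placeholders for a model-specific Parallelizer subclass and a loaded PreTrainedModel; argument order and defaults not visible in this diff are assumptions.

# Not runnable standalone: parallelizer_cls and model are placeholders.
model = parallelizer_cls.parallelize(
    model,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
    pipeline_parallel_num_microbatches=4,
    pipeline_parallel_use_zero1_optimizer=False,
    gradient_checkpointing=True,  # wired above to call apply_checkpoint(model)
    checkpoint_dir=None,
)
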
4 changes: 1 addition & 3 deletions optimum/neuron/distributed/parallel_layers.py
@@ -27,8 +27,8 @@

from ...utils import NormalizedConfigManager, logging
from ..utils import patch_everywhere, patch_within_function
from ..utils.require_utils import requires_neuronx_distributed
from ..utils.misc import is_main_worker
from ..utils.require_utils import requires_neuronx_distributed
from .utils import (
GroupedQueryAttentionInfo,
WeightInformation,
@@ -238,8 +238,6 @@ def _transform(
embedding_layer = layer.get_submodule(cls.EMBEDDING_NAME)
tp_size = parallel_state.get_tensor_model_parallel_size()
if embedding_layer.num_embeddings % tp_size != 0:
import torch_xla.core.xla_model as xm

if is_main_worker():
logger.warning(
f"Embedding parallelization for TP was skipped because the tensor parallel size ({tp_size}) does not "
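
The warning touched in this hunk is driven by a simple divisibility requirement; as a toy illustration (the numbers are made up):

vocab_size, tp_size = 32000, 8                       # made-up values
if vocab_size % tp_size != 0:
    print("embedding parallelization skipped")       # the warning path above
else:
    rows_per_rank = vocab_size // tp_size            # 4000 embedding rows per TP rank
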
122 changes: 120 additions & 2 deletions optimum/neuron/distributed/utils.py
@@ -22,26 +22,41 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union

import torch
from torch.distributed.utils import _replace_by_prefix
from transformers import PretrainedConfig
from transformers.utils import is_peft_available

from ...utils import logging
from ..utils import DynamicPatch, Patcher
from ..utils.deprecate_utils import deprecate
from ..utils.import_utils import is_neuronx_distributed_available
from ..utils.misc import download_checkpoints_in_cache
from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla
from ..utils.require_utils import (
is_torch_xla_available,
requires_neuronx_distributed,
requires_safetensors,
requires_torch_xla,
)


if is_neuronx_distributed_available():
from neuronx_distributed.parallel_layers import layers
from neuronx_distributed.parallel_layers.parallel_state import rmsg
from neuronx_distributed.pipeline import NxDPPModel

if is_torch_xla_available():
from torch_xla.utils.checkpoint import checkpoint as torch_checkpoint

if TYPE_CHECKING:
from transformers import PreTrainedModel


logger = logging.get_logger()


TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"


@@ -886,3 +901,106 @@ def is_tied(self):
def is_sharded(self):
return self.kind == "sharded"


# The following code for gradient checkpointing was taken from:
# https://github.com/aws-neuron/neuronx-distributed/blob/main/examples/training/llama2/tp_pp_llama2_hf_pretrain/activation_checkpoint.py

_CHECKPOINT_WRAPPED_MODULE = "mod"
_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."


class CheckPointWrapper(torch.nn.Module):
def __init__(self, mod) -> None:
super().__init__()
self.mod = mod
# state_dict post hook to remove prefix to allow loading into a
# non-checkpoint wrapped module.
self._register_state_dict_hook(self._post_state_dict_hook)
# load_state_dict pre-hook to allow loading back into
# checkpoint-wrapped module.
self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook, with_module=True)

def forward(self, *args, **kwargs):
ordered_args = list(args)
for value in kwargs.values():
ordered_args += [value]

# Note: checkpoint cannot accept kwargs
return torch_checkpoint(self.mod, *ordered_args, use_reentrant=True)

def named_parameters(
self,
*args,
**kwargs,
) -> Iterator[Tuple[str, torch.nn.Parameter]]:
"""
Overrides :meth:`named_parameters()` to intercept parameter names and
remove all occurrences of ``_CHECKPOINT_PREFIX``.
"""
for param_name, param in super().named_parameters(*args, **kwargs):
updated_name = param_name.replace(_CHECKPOINT_PREFIX, "")
yield updated_name, param

def named_modules(self, *args, **kwargs):
for module_name, module in super().named_modules(*args, **kwargs):
updated_name = module_name.replace(_CHECKPOINT_PREFIX, "")
yield updated_name, module

@staticmethod
def _post_state_dict_hook(
module: "torch.nn.Module",
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
"""
_post_state_dict_hook() is called after the state_dict() of this
FSDP module is executed. For ``checkpoint_wrapper``, it will strip
checkpoint-wrapped module prefix so that this module can be loaded into
non-checkpointed modules. It would still be able to be loaded into
checkpoint-wrapped modules as this class adds the prefix back before
loading the state_dict.
"""
_replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
return state_dict

@staticmethod
def _pre_load_state_dict_hook(
module: "torch.nn.Module",
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> None:
"""
``_pre_state_dict_hook` is called before ``self._load_from_state_dict()``
is called. For ``checkpoint_wrapper``, it will add back the module
prefix so that non-checkpointed modules can be loaded into
checkpoint_wrapper modules properly.
"""
_replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")


def apply_checkpoint(dist_model: "NxDPPModel", layers_to_checkpoint: Optional[List["torch.nn.Module"]] = None):
checkpoint_wrapper_added = False
if layers_to_checkpoint is not None and len(layers_to_checkpoint) == 0:
raise RuntimeError(rmsg(f"invalid input layers_to_checkpoint {layers_to_checkpoint}, can't be empty"))
for name, module in dist_model.local_module.named_children():
# checkpoint layers that are provided in input
# if layers not provide in input, then checkpoint if it is transformer layer
if (layers_to_checkpoint and name in layers_to_checkpoint) or (
not layers_to_checkpoint and type(module) == dist_model.transformer_layer_cls
):
# add_module replaces old module with our own custom module.
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module
dist_model.local_module.add_module(name, CheckPointWrapper(module))
checkpoint_wrapper_added = True
if layers_to_checkpoint is not None and not checkpoint_wrapper_added:
logger.warning(rmsg(f"layers_to_checkpoint {layers_to_checkpoint} do not exist in the graph"))
elif layers_to_checkpoint is None and not checkpoint_wrapper_added:
logger.warning(
rmsg(
"During applying activation checkpointing, transformer_layer_cls "
f"{dist_model.transformer_layer_cls.__name__} can not be found in stage "
f"{dist_model.pipeline_parallel_rank}, skipping..."
)
)
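
A small sketch of what CheckPointWrapper preserves: parameter and state_dict names lose the "mod." prefix, so a checkpoint-wrapped layer stays load-compatible with the unwrapped module. This assumes an environment where optimum-neuron and its Neuron dependencies import cleanly; the forward pass itself needs an XLA device because it goes through torch_xla's checkpoint.

import torch
from optimum.neuron.distributed.utils import CheckPointWrapper

layer = torch.nn.Linear(16, 16)
wrapped = CheckPointWrapper(layer)

# Prefix stripping: names match the unwrapped module, not "mod.weight"/"mod.bias".
print([name for name, _ in wrapped.named_parameters()])  # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))                 # ['weight', 'bias']

# The forward recomputes activations during backward via torch_xla's
# checkpoint; on an XLA device it would look like:
#   out = wrapped(torch.randn(2, 16, device=xm.xla_device()))
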
10 changes: 8 additions & 2 deletions optimum/neuron/trainer_callback.py
@@ -311,7 +311,11 @@ def synchronize_temporary_neuron_cache(self):
# pushed_directories = set()
allow_patterns = [file.as_posix() for file in files]
push_to_cache_on_hub(
neuron_hash, self.tmp_neuron_cache_path, cache_repo_id=self.cache_repo_id, local_path_to_path_in_repo="default", allow_patterns=allow_patterns,
neuron_hash,
self.tmp_neuron_cache_path,
cache_repo_id=self.cache_repo_id,
local_path_to_path_in_repo="default",
allow_patterns=allow_patterns,
)

for path in files:
@@ -387,7 +391,9 @@ def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control
neuron_hash = entry["neuron_hash"]
module_dir = Path(entry["directory"])
cache_dir = module_dir.parent
filenames = [file.as_posix() for file in list_files_in_neuron_cache(module_dir, only_relevant_files=True)]
filenames = [
file.as_posix() for file in list_files_in_neuron_cache(module_dir, only_relevant_files=True)
]
success = True
try:
push_to_cache_on_hub(