From 6eeeaa0070153142aa27101ceeeeb41716a38600 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Mon, 5 Feb 2024 14:37:03 +0100 Subject: [PATCH] [WIP] llama-70b --- optimum/neuron/accelerate/accelerator.py | 10 +- .../neuron/accelerate/utils/dataclasses.py | 2 + optimum/neuron/distributed/base.py | 18 ++- optimum/neuron/distributed/parallel_layers.py | 4 +- optimum/neuron/distributed/utils.py | 122 +++++++++++++++++- optimum/neuron/trainer_callback.py | 10 +- optimum/neuron/trainers.py | 111 +++++++++------- optimum/neuron/training_args.py | 3 +- optimum/neuron/utils/training_utils.py | 53 +++++++- 9 files changed, 264 insertions(+), 69 deletions(-) diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index d7bbd6ae2..d1bcf637c 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -173,7 +173,11 @@ def __init__(self, *args, mp_plugin: Optional[ModelParallelismPlugin] = None, ze self.gradient_accumulation_steps = num_steps def _prepare_data_loader_for_distributed( - self, data_loader: DataLoader, num_replicas: int, rank: int, force_drop_last: bool, + self, + data_loader: DataLoader, + num_replicas: int, + rank: int, + force_drop_last: bool, ) -> DataLoader: # TODO: make it more robust, similar to the prepare_data_loader function in `accelerate`. if isinstance(data_loader.sampler, DistributedSampler): @@ -224,7 +228,9 @@ def prepare_data_loader(self, data_loader: DataLoader, device_placement: Optiona num_replicas = xm.xrt_world_size() rank = xm.get_local_ordinal() if self.state.num_processes > 1: - data_loader = self._prepare_data_loader_for_distributed(data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last) + data_loader = self._prepare_data_loader_for_distributed( + data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last + ) # No need to wrap the dataloader if we are using pipeline parallelism. 
 if self.state.mp_plugin.pipeline_parallel_size == 1:
             data_loader = MpDeviceLoader(data_loader, self.device)
diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py
index f4d0dc0dd..7e88106eb 100644
--- a/optimum/neuron/accelerate/utils/dataclasses.py
+++ b/optimum/neuron/accelerate/utils/dataclasses.py
@@ -147,6 +147,7 @@ class ModelParallelismPlugin:
     pipeline_parallel_size: int = 1
     pipeline_parallel_num_microbatches: int = 1
     pipeline_parallel_use_zero1_optimizer: bool = False
+    gradient_checkpointing: bool = False
     checkpoint_dir: Optional[Union[str, Path]] = None
 
     def __post_init__(self):
@@ -176,6 +177,7 @@ def parallelize_model(
             sequence_parallel_enabled=self.sequence_parallel_enabled,
             pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches,
             pipeline_parallel_use_zero1_optimizer=self.pipeline_parallel_use_zero1_optimizer,
+            gradient_checkpointing=self.gradient_checkpointing,
             checkpoint_dir=self.checkpoint_dir,
         )
         return parallelized_model
diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 9b624f065..e413bdda5 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -21,7 +21,7 @@
 from dataclasses import asdict
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, Tuple, Type, Union, Callable
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Set, Tuple, Type, Union
 
 import torch
 from transformers import PreTrainedModel
@@ -29,9 +29,9 @@
 
 from ...utils import logging
 from ..utils import is_neuronx_distributed_available, is_torch_xla_available
+from ..utils.misc import is_main_worker
 from ..utils.patching import Patcher
 from ..utils.require_utils import requires_neuronx_distributed, requires_torch_xla
-from ..utils.misc import is_main_worker
 from .parallel_layers import (
     IOSequenceParallelizer,
     LayerNormSequenceParallelizer,
@@ -42,6 +42,7 @@
     TENSOR_PARALLEL_SHARDS_DIR_NAME,
     ParameterMetadata,
     WeightInformation,
+    apply_checkpoint,
     initialize_parallel_linear,
     initialize_torch_nn_module,
     linear_to_parallel_linear,
@@ -243,7 +244,7 @@ def _parallelize(
         device: Optional["torch.device"] = None,
         parallelize_embeddings: bool = True,
         sequence_parallel_enabled: bool = False,
-        should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None
+        should_parallelize_predicate_func: Optional[Callable[["torch.nn.Module"], "torch.nn.Module"]] = None,
     ) -> "PreTrainedModel":
         """
         Parallelizes the model by transforming regular layer into their parallel counterparts.
@@ -275,6 +276,7 @@ def parallelize(
         pipeline_parallel_input_names: Optional[Union[Tuple[str, ...], List[str]]] = None,
         pipeline_parallel_num_microbatches: int = 1,
         pipeline_parallel_use_zero1_optimizer: bool = False,
+        gradient_checkpointing: bool = False,
         checkpoint_dir: Optional[Union[str, Path]] = None,
     ) -> "PreTrainedModel":
         """
@@ -299,6 +301,8 @@ def parallelize(
             pipeline_parallel_use_zero1_optimizer (`bool`, defaults to `False`):
                 When zero-1 optimizer is used, set this to True, so the PP model will understand that zero-1 optimizer
                 will handle data parallel gradient averaging.
+            gradient_checkpointing (`bool`, defaults to `False`):
+                Whether to enable gradient (activation) checkpointing on the parallelized model.
             checkpoint_dir (`Optional[Union[str, Path]]`):
                 Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized
                 model.
@@ -330,14 +334,15 @@ def parallelize( model, remove_duplicate=True ) - name_to_parameter = dict(named_parameters(model, remove_duplicate=False)) parameter_to_name = {p: n for n, p in name_to_parameter.items()} + xm.master_print(name_to_parameter.keys()) + def predicate_func(layer): for n, p in layer.named_parameters(): if p not in parameter_to_name: - print(n) + xm.master_print(n) return False names = {parameter_to_name[p] for p in layer.parameters()} return names < names_of_the_parameters_to_consider @@ -527,13 +532,14 @@ def predicate_func(layer): leaf_module_cls=cls.PIPELINE_PARALLELISM_SPECS_CLS.leaf_module_cls(), use_zero1_optimizer=pipeline_parallel_use_zero1_optimizer, ) + if gradient_checkpointing: + apply_checkpoint(model) xm.rendezvous("End of pipeline paralellism") if checkpoint_dir is not None: cls.load_model_checkpoint(model, checkpoint_dir) - return model @classmethod diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py index 1f4cb24bc..554460d11 100644 --- a/optimum/neuron/distributed/parallel_layers.py +++ b/optimum/neuron/distributed/parallel_layers.py @@ -27,8 +27,8 @@ from ...utils import NormalizedConfigManager, logging from ..utils import patch_everywhere, patch_within_function -from ..utils.require_utils import requires_neuronx_distributed from ..utils.misc import is_main_worker +from ..utils.require_utils import requires_neuronx_distributed from .utils import ( GroupedQueryAttentionInfo, WeightInformation, @@ -238,8 +238,6 @@ def _transform( embedding_layer = layer.get_submodule(cls.EMBEDDING_NAME) tp_size = parallel_state.get_tensor_model_parallel_size() if embedding_layer.num_embeddings % tp_size != 0: - import torch_xla.core.xla_model as xm - if is_main_worker(): logger.warning( f"Embedding parallelization for TP was skipped because the tensor parallel size ({tp_size}) does not " diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py index f5837e8ff..fc3a98703 100644 --- a/optimum/neuron/distributed/utils.py +++ b/optimum/neuron/distributed/utils.py @@ -22,26 +22,41 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union import torch +from torch.distributed.utils import _replace_by_prefix from transformers import PretrainedConfig from transformers.utils import is_peft_available +from ...utils import logging from ..utils import DynamicPatch, Patcher from ..utils.deprecate_utils import deprecate from ..utils.import_utils import is_neuronx_distributed_available from ..utils.misc import download_checkpoints_in_cache -from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla +from ..utils.require_utils import ( + is_torch_xla_available, + requires_neuronx_distributed, + requires_safetensors, + requires_torch_xla, +) if is_neuronx_distributed_available(): from neuronx_distributed.parallel_layers import layers + from neuronx_distributed.parallel_layers.parallel_state import rmsg + from neuronx_distributed.pipeline import NxDPPModel + +if is_torch_xla_available(): + from torch_xla.utils.checkpoint import checkpoint as torch_checkpoint if TYPE_CHECKING: from transformers import PreTrainedModel +logger = logging.get_logger() + + TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards" @@ -886,3 +901,106 @@ def is_tied(self): def 
is_sharded(self):
         return self.kind == "sharded"
+
+# The following code for gradient checkpointing was taken from:
+# https://github.com/aws-neuron/neuronx-distributed/blob/main/examples/training/llama2/tp_pp_llama2_hf_pretrain/activation_checkpoint.py
+
+_CHECKPOINT_WRAPPED_MODULE = "mod"
+_CHECKPOINT_PREFIX = _CHECKPOINT_WRAPPED_MODULE + "."
+
+
+class CheckPointWrapper(torch.nn.Module):
+    def __init__(self, mod) -> None:
+        super().__init__()
+        self.mod = mod
+        # state_dict post hook to remove prefix to allow loading into a
+        # non-checkpoint wrapped module.
+        self._register_state_dict_hook(self._post_state_dict_hook)
+        # load_state_dict pre-hook to allow loading back into
+        # checkpoint-wrapped module.
+        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook, with_module=True)
+
+    def forward(self, *args, **kwargs):
+        ordered_args = list(args)
+        for value in kwargs.values():
+            ordered_args += [value]
+
+        # Note: checkpoint cannot accept kwargs
+        return torch_checkpoint(self.mod, *ordered_args, use_reentrant=True)
+
+    def named_parameters(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """
+        Overrides :meth:`named_parameters()` to intercept parameter names and
+        remove all occurrences of ``_CHECKPOINT_PREFIX``.
+        """
+        for param_name, param in super().named_parameters(*args, **kwargs):
+            updated_name = param_name.replace(_CHECKPOINT_PREFIX, "")
+            yield updated_name, param
+
+    def named_modules(self, *args, **kwargs):
+        for module_name, module in super().named_modules(*args, **kwargs):
+            updated_name = module_name.replace(_CHECKPOINT_PREFIX, "")
+            yield updated_name, module
+
+    @staticmethod
+    def _post_state_dict_hook(
+        module: "torch.nn.Module",
+        state_dict: Dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> Dict[str, Any]:
+        """
+        _post_state_dict_hook() is called after the state_dict() of this
+        checkpoint-wrapped module is executed. For ``checkpoint_wrapper``, it strips
+        the checkpoint-wrapped module prefix so that this module can be loaded into
+        non-checkpointed modules. It would still be able to be loaded into
+        checkpoint-wrapped modules as this class adds the prefix back before
+        loading the state_dict.
+        """
+        _replace_by_prefix(state_dict, f"{prefix}{_CHECKPOINT_PREFIX}", prefix)
+        return state_dict
+
+    @staticmethod
+    def _pre_load_state_dict_hook(
+        module: "torch.nn.Module",
+        state_dict: Dict[str, Any],
+        prefix: str,
+        *args: Any,
+    ) -> None:
+        """
+        ``_pre_load_state_dict_hook`` is called before ``self._load_from_state_dict()``
+        is called. For ``checkpoint_wrapper``, it will add back the module
+        prefix so that non-checkpointed modules can be loaded into
+        checkpoint_wrapper modules properly.
+        """
+        _replace_by_prefix(state_dict, prefix, prefix + f"{_CHECKPOINT_PREFIX}")
+
+
+def apply_checkpoint(dist_model: "NxDPPModel", layers_to_checkpoint: Optional[List[str]] = None):
+    checkpoint_wrapper_added = False
+    if layers_to_checkpoint is not None and len(layers_to_checkpoint) == 0:
+        raise RuntimeError(rmsg(f"invalid input layers_to_checkpoint {layers_to_checkpoint}, can't be empty"))
+    for name, module in dist_model.local_module.named_children():
+        # checkpoint layers that are provided in input
+        # if layers are not provided in input, then checkpoint if it is a transformer layer
+        if (layers_to_checkpoint and name in layers_to_checkpoint) or (
+            not layers_to_checkpoint and type(module) == dist_model.transformer_layer_cls
+        ):
+            # add_module replaces old module with our own custom module.
+ # https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module + dist_model.local_module.add_module(name, CheckPointWrapper(module)) + checkpoint_wrapper_added = True + if layers_to_checkpoint is not None and not checkpoint_wrapper_added: + logger.warning(rmsg(f"layers_to_checkpoint {layers_to_checkpoint} do not exist in the graph")) + elif layers_to_checkpoint is None and not checkpoint_wrapper_added: + logger.warning( + rmsg( + "During applying activation checkpointing, transformer_layer_cls " + f"{dist_model.transformer_layer_cls.__name__} can not be found in stage " + f"{dist_model.pipeline_parallel_rank}, skipping..." + ) + ) diff --git a/optimum/neuron/trainer_callback.py b/optimum/neuron/trainer_callback.py index fe3b31127..483cf9f73 100644 --- a/optimum/neuron/trainer_callback.py +++ b/optimum/neuron/trainer_callback.py @@ -311,7 +311,11 @@ def synchronize_temporary_neuron_cache(self): # pushed_directories = set() allow_patterns = [file.as_posix() for file in files] push_to_cache_on_hub( - neuron_hash, self.tmp_neuron_cache_path, cache_repo_id=self.cache_repo_id, local_path_to_path_in_repo="default", allow_patterns=allow_patterns, + neuron_hash, + self.tmp_neuron_cache_path, + cache_repo_id=self.cache_repo_id, + local_path_to_path_in_repo="default", + allow_patterns=allow_patterns, ) for path in files: @@ -387,7 +391,9 @@ def on_train_begin(self, args: "TrainingArguments", state: TrainerState, control neuron_hash = entry["neuron_hash"] module_dir = Path(entry["directory"]) cache_dir = module_dir.parent - filenames = [file.as_posix() for file in list_files_in_neuron_cache(module_dir, only_relevant_files=True)] + filenames = [ + file.as_posix() for file in list_files_in_neuron_cache(module_dir, only_relevant_files=True) + ] success = True try: push_to_cache_on_hub( diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 9021c71ac..158a37dc6 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -75,6 +75,7 @@ patch_within_function, ) from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path +from .utils.misc import is_main_worker from .utils.require_utils import requires_neuronx_distributed from .utils.training_utils import ( TRANSFORMERS_MIN_VERSION_USE_ACCELERATE, @@ -83,8 +84,8 @@ is_topology_supported, patch_generation_mixin_to_neuron_generation_mixin, prepare_environment_for_neuron, - set_neuron_cc_optlevel_for_model, set_neuron_cc_flags_for_model, + set_neuron_cc_optlevel_for_model, skip_first_batches, torch_xla_safe_save_file, ) @@ -219,13 +220,13 @@ def mp_enabled(self): return self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM def prepare_args_for_precompilation(self, args: "TrainingArguments"): - if args.num_train_epochs != 1: + if is_main_worker() and args.num_train_epochs != 1: logger.info("Setting the number of epochs for precompilation to 1.") args.num_train_epochs = 1 - if args.do_eval is True: + if is_main_worker() and args.do_eval is True: logger.info("Disabling evaluation during precompilation as this is not well supported yet.") args.do_eval = False - if args.do_predict is True: + if is_main_worker() and args.do_predict is True: logger.info("Disabling prediction during precompilation as this is not well supported yet.") args.do_predict = False @@ -451,7 +452,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for def _save_xla(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else 
self.args.output_dir
-        logger.info(f"Saving model checkpoint to {output_dir}")
+        if is_main_worker():
+            logger.info(f"Saving model checkpoint to {output_dir}")
 
         if xm.is_master_ordinal():
             os.makedirs(output_dir, exist_ok=True)
@@ -461,7 +463,8 @@ def _save_xla(self, output_dir: Optional[str] = None):
         # They can then be reloaded using `from_pretrained()`
         xm.rendezvous("saving_checkpoint")
         if self.accelerator.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
-            logger.info("Model parallelism is enabled, only saving the model sharded state dict.")
+            if is_main_worker():
+                logger.info("Model parallelism is enabled, only saving the model sharded state dict.")
             # TODO: how to handle pp?
             if isinstance(self.model, PreTrainedModel):
                 from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
@@ -488,7 +491,8 @@ def _save_xla(self, output_dir: Optional[str] = None):
                     save_function=xm.save,
                 )
             else:
-                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if is_main_worker():
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 state_dict = self.model.state_dict()
                 xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
@@ -510,7 +514,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             # Push to the Hub when `save_model` is called by the user.
             if self.args.push_to_hub and not _internal_call:
                 self.push_to_hub(commit_message="Model save")
-        else:
+        elif is_main_worker():
             logger.info("Skipping trainer.save_model() while running under neuron_parallel_compile")
 
     def _save_checkpoint(self, model, trial, metrics=None):
@@ -634,7 +638,8 @@ def _inner_training_loop(
 
         self.accelerator.free_memory()
         self._train_batch_size = batch_size
-        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+        if is_main_worker():
+            logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
 
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
@@ -726,7 +731,8 @@ def _inner_training_loop(
             self.state.save_steps = args.save_steps
 
         # Activate gradient checkpointing if needed
-        if args.gradient_checkpointing:
+        # It is handled differently if pipeline parallelism is enabled.
+        if args.gradient_checkpointing and args.pipeline_parallel_size == 1:
             if args.gradient_checkpointing_kwargs is None:
                 gradient_checkpointing_kwargs = {}
             else:
@@ -779,16 +785,22 @@ def _inner_training_loop(
         # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
 
         # Train!
-        logger.info("***** Running training *****")
-        logger.info(f"  Num examples = {num_examples:,}")
-        logger.info(f"  Num Epochs = {num_train_epochs:,}")
-        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
-        if self.args.per_device_train_batch_size != self._train_batch_size:
-            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
-        logger.info(f"  Total train batch size (w.
parallel, distributed & accumulation) = {total_train_batch_size:,}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps:,}") - logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + parameter_count = get_model_param_count(model, trainable_only=True) + if is_main_worker(): + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info( + f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}" + ) + logger.info( + f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}" + ) + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {parameter_count:,}") self.state.epoch = 0 start_time = time.time() @@ -808,14 +820,15 @@ def _inner_training_loop( else: steps_trained_in_current_epoch = 0 - logger.info(" Continuing training from checkpoint, will skip to saved global_step") - logger.info(f" Continuing training from epoch {epochs_trained}") - logger.info(f" Continuing training from global step {self.state.global_step}") - if not args.ignore_data_skip: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first" - f" {steps_trained_in_current_epoch} batches in the first epoch." - ) + if is_main_worker(): + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) # Update the references self.callback_handler.model = self.model @@ -999,11 +1012,12 @@ def _inner_training_loop( if self.control.should_epoch_stop or self.control.should_training_stop: break if step < 0: - logger.warning( - "There seems to be not a single sample in your epoch_iterator, stopping training at step" - f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" - f" num_steps ({max_steps}) higher than the number of available samples." - ) + if is_main_worker(): + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) self.control.should_training_stop = True self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) @@ -1025,7 +1039,8 @@ def _inner_training_loop( # Clean the state at the end of training delattr(self, "_past") - logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if is_main_worker(): + logger.info("\n\nTraining completed. 
Do not forget to share your model on huggingface.co/models =)\n\n")
         if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
             # Wait for everyone to get here so we are sure the model has been saved by process 0.
             if is_torch_xla_available():
                 xm.rendezvous("load_best_model_at_end")
@@ -1065,7 +1080,8 @@ def _inner_training_loop(
         if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
             for checkpoint in checkpoints_sorted:
                 if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
-                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                    if is_main_worker():
+                        logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                     shutil.rmtree(checkpoint)
 
         self.control = self.callback_handler.on_train_end(args, self.state, self.control)
@@ -1139,17 +1155,18 @@ def evaluation_loop(
 
         batch_size = self.args.eval_batch_size
 
-        logger.info(f"***** Running {description} *****")
-        dp_size = get_data_parallel_size()
-        logger.info(f"  Num data parallel workers = {dp_size}")
-        if has_length(dataloader):
-            num_examples = self.num_examples(dataloader)
-            total_num_examples = num_examples * dp_size
-            logger.info(f"  Per data parallel worker num examples = {num_examples}")
-            logger.info(f"  Total num examples = {total_num_examples}")
-        else:
-            logger.info("  Num examples: Unknown")
-        logger.info(f"  Batch size = {batch_size}")
+        if is_main_worker():
+            logger.info(f"***** Running {description} *****")
+            dp_size = get_data_parallel_size()
+            logger.info(f"  Num data parallel workers = {dp_size}")
+            if has_length(dataloader):
+                num_examples = self.num_examples(dataloader)
+                total_num_examples = num_examples * dp_size
+                logger.info(f"  Per data parallel worker num examples = {num_examples}")
+                logger.info(f"  Total num examples = {total_num_examples}")
+            else:
+                logger.info("  Num examples: Unknown")
+            logger.info(f"  Batch size = {batch_size}")
 
         if not is_nxdppmodel:
             model.eval()
@@ -1187,7 +1204,7 @@ def evaluation_loop(
                     batch_size = observed_batch_size
 
             if is_nxdppmodel and observed_batch_size % model.num_microbatches != 0:
-                if xm.get_local_ordinal() == 0:
+                if is_main_worker():
                     logger.warning(
                         "Skipping the evaluation step because the pipeline number of microbatches "
                         f"({model.num_microbatches}) does not divide the batch size ({observed_batch_size})."
diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index c0d200675..85a8a4a58 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -38,8 +38,8 @@ from .accelerate import NeuronAcceleratorState, NeuronPartialState from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_tpu_available from .utils import is_accelerate_available, is_torch_xla_available -from .utils.training_utils import TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP from .utils.patching import Patcher +from .utils.training_utils import TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP if is_sagemaker_mp_enabled(): @@ -150,6 +150,7 @@ def __post_init__(self): pipeline_parallel_size=self.pipeline_parallel_size, pipeline_parallel_num_microbatches=self.pipeline_parallel_num_microbatches, pipeline_parallel_use_zero1_optimizer=self.zero_1, + gradient_checkpointing=self.gradient_checkpointing, checkpoint_dir=resume_from_checkpoint, ) diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index c6157ba22..7b62fb448 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -44,15 +44,18 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) -from transformers.trainer_pt_utils import get_model_param_count as transformers_get_model_param_count from transformers.utils.logging import set_verbosity as set_verbosity_transformers from ...utils.logging import set_verbosity as set_verbosity_optimum from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin -from . import is_torch_xla_available +from . import is_neuronx_distributed_available, is_torch_xla_available from .require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla +if is_neuronx_distributed_available(): + from neuronx_distributed.pipeline import NxDPPModel + + if TYPE_CHECKING: from transformers import PreTrainedModel @@ -378,7 +381,45 @@ def torch_xla_safe_save_file( save_file(cpu_data, filename, metadata=metadata) -def get_model_param_count(model, trainable_only=False): - """Wrapper around `transformers.trainer_pt_utils.get_model_param_count` to handle tensor parallelism.""" - # TODO: make it work for TP - return transformers_get_model_param_count(model, trainable_only=trainable_only) +@requires_neuronx_distributed +def get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"], trainable_only: bool = False): + """Counts the number of parameters of `model`.""" + import torch_xla.core.xla_model as xm + from neuronx_distributed.parallel_layers.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_size, + get_tensor_model_parallel_size, + ) + from neuronx_distributed.pipeline import NxDPPModel + from neuronx_distributed.pipeline.partition import analyze_shared_weights_across_stages + + if isinstance(model, NxDPPModel): + named_parameters = model.local_named_parameters() + shared = analyze_shared_weights_across_stages(model.traced_model, model.partitions) + shared_parameters_across_pipeline_stages = { + t[0]: t[1] for shared_parameter_info in shared for t in shared_parameter_info + } + else: + named_parameters = model.named_parameters() + shared_parameters_across_pipeline_stages = {} + + pp_rank = get_pipeline_model_parallel_rank() + + def numel(parameter_name, parameter) -> int: + should_count_param = shared_parameters_across_pipeline_stages.get(parameter_name, pp_rank) == pp_rank + + num_elements = 
parameter.numel() + if getattr(parameter, "tensor_model_parallel", False): + num_elements *= get_tensor_model_parallel_size() + + return num_elements if should_count_param else 0 + + param_count = sum(numel(n, p) for n, p in named_parameters if not trainable_only or p.requires_grad) + + if get_pipeline_model_parallel_size() > 1: + param_count = torch.tensor(param_count, dtype=torch.double).to(xm.xla_device()) + param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True)) + param_count = param_count.detach().item() + + return param_count
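
Usage sketch (not part of the patch): how the new flag is expected to be driven end to end once this lands. It assumes the existing `NeuronTrainingArguments` entry point and its `tensor_parallel_size`/`pipeline_parallel_size` fields; the concrete values and output path below are illustrative only.

    from optimum.neuron import NeuronTrainingArguments

    # With pipeline_parallel_size > 1 the Hugging Face gradient-checkpointing path is skipped
    # (see the trainers.py hunk above); instead the flag is forwarded through
    # ModelParallelismPlugin so that Parallelizer.parallelize() calls apply_checkpoint()
    # on the transformer layers of the local NxDPPModel stage.
    training_args = NeuronTrainingArguments(
        output_dir="llama70b_output",   # illustrative path
        tensor_parallel_size=8,         # assumed existing field
        pipeline_parallel_size=4,       # assumed existing field
        gradient_checkpointing=True,    # new: ends up in ModelParallelismPlugin.gradient_checkpointing
        per_device_train_batch_size=1,
    )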