From f734478d63482061d26a9b169522bb46dd6a359f Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Wed, 7 Feb 2024 15:19:10 +0100
Subject: [PATCH] [WIP] llama-70b

---
 optimum/neuron/distributed/base.py  | 17 ++++++----
 optimum/neuron/distributed/utils.py | 51 +++++++++++++++++++----------
 optimum/neuron/trainers.py          | 26 +++++++--------
 3 files changed, 56 insertions(+), 38 deletions(-)

diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 3e1a08966..8652be417 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -387,13 +387,16 @@ def predicate_func(layer):
 
         weight_map = getattr(model, "_weight_map", {})
 
-        for fully_qualified_name, layer in model.named_modules():
-            if isinstance(layer, BaseParallelLinear):
-                xm.master_print(fully_qualified_name)
-                linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info(weight_map, fully_qualified_name)
-                if linear_weight_info is not None:
-                    maybe_load_linear_weight_to_parallel_linear(layer, linear_layer_weight_info=linear_weight_info, linear_layer_bias_weight_info=linear_bias_weight_info)
-        xm.master_print("PARALLEL LAYERS DONE")
+        # for fully_qualified_name, layer in model.named_modules():
+        #     if isinstance(layer, BaseParallelLinear):
+        #         xm.master_print(fully_qualified_name)
+        #         try:
+        #             linear_weight_info, linear_bias_weight_info = ParallelLayer._get_linear_weight_info(weight_map, fully_qualified_name)
+        #         except ValueError:
+        #             linear_weight_info = None
+        #         if linear_weight_info is not None:
+        #             maybe_load_linear_weight_to_parallel_linear(layer, linear_layer_weight_info=linear_weight_info, linear_layer_bias_weight_info=linear_bias_weight_info)
+        # xm.master_print("PARALLEL LAYERS DONE")
 
         with torch.no_grad():
             tied_weights = {}
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index dce088c4a..ded3bf0ec 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -33,7 +33,7 @@
 from ..utils import DynamicPatch, Patcher
 from ..utils.deprecate_utils import deprecate
 from ..utils.import_utils import is_neuronx_distributed_available
-from ..utils.misc import download_checkpoints_in_cache
+from ..utils.misc import download_checkpoints_in_cache, is_main_worker
 from ..utils.require_utils import (
     is_torch_xla_available,
     requires_neuronx_distributed,
@@ -291,7 +291,7 @@ def embedding_to_parallel_embedding(
         )
 
     del embedding_layer.weight
-
+    
     if lm_head_layer is None:
         return parallel_embedding_layer
 
@@ -305,11 +305,11 @@ def maybe_load_linear_weight_to_parallel_linear(
     linear_layer_bias_weight_info: Optional[WeightInformation] = None,
     linear_layer: Optional["torch.nn.Linear"] = None,
 ):
-    if linear_layer_weight_info is not None and linear_layer is not None:
+    if (linear_layer_weight_info is not None or linear_layer_bias_weight_info is not None) and linear_layer is not None:
         raise ValueError(
             "Specify either a linear layer's WeightInformation, or a linear layer to copy the weights from, but not both."
         )
-    if linear_layer_weight_info is None and linear_layer is None:
+    if linear_layer_weight_info is None and linear_layer_bias_weight_info is None and linear_layer is None:
         raise ValueError(
             "A linear's layer WeightInformation or a linear layer to copy the weight from need to specified."
         )
@@ -412,7 +412,7 @@ def linear_to_parallel_linear(
     linear_layer_bias_weight_info: Optional[WeightInformation] = None,
     embedding_weight_to_tie: Optional["torch.nn.Parameter"] = None,
     sequence_parallel_enabled: bool = False,
-    skip_weight_load: bool = True,
+    skip_weight_load: bool = False,
     device: Optional["torch.device"] = None,
 ) -> Union["layers.RowParallelLinear", "layers.ColumnParallelLinear"]:
     """
@@ -724,22 +724,37 @@ def from_pretrained_for_mp(
 
     if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
         adapter_kwargs["token"] = token
 
-    filenames, sharded_metadata = download_checkpoints_in_cache(
-        pretrained_model_name_or_path,
-        cache_dir=cache_dir,
-        force_download=force_download,
-        local_files_only=local_files_only,
-        token=token,
-        revision=revision,
-        use_safetensors=use_safetensors,
-        use_safetensors_in_priority=True,
-        convert_to_safetensors=True,
-        **kwargs,
-    )
-    import torch_xla.core.xla_model as xm
+    if is_main_worker():
+        filenames, sharded_metadata = download_checkpoints_in_cache(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            use_safetensors=use_safetensors,
+            use_safetensors_in_priority=True,
+            convert_to_safetensors=True,
+            **kwargs,
+        )
+    xm.rendezvous("waiting after download and conversion")
+
+    if not is_main_worker():
+        filenames, sharded_metadata = download_checkpoints_in_cache(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            use_safetensors=use_safetensors,
+            use_safetensors_in_priority=True,
+            convert_to_safetensors=True,
+            **kwargs,
+        )
 
     if not isinstance(config, PretrainedConfig):
         config_path = config if config is not None else pretrained_model_name_or_path
 
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index ebfe94b2a..60f7caa71 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -127,7 +127,7 @@
     import torch_xla.distributed.xla_backend as xbn
 
     if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
-        _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
+        # _ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()
 
         # _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
         # if _ORIGINAL_NEURON_CACHE_PATH is not None:
@@ -196,18 +196,18 @@ def __init__(self, *args, **kwargs):
         if self.args.local_rank <= 0:
             logger.setLevel(logging.INFO)
 
-        rank = xm.get_ordinal()
-        push = rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
-        fetch = rank <= 0 or self.args.mp_plugin.should_parallelize
-
-        callback = NeuronCacheCallback(
-            tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
-            original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
-            fetch=fetch,
-            push=push,
-            wait_for_everyone_on_fetch=True,
-            wait_for_everyone_on_push=True,
-        )
+        # rank = xm.get_ordinal()
+        # push = rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
+        # fetch = rank <= 0 or self.args.mp_plugin.should_parallelize
+
+        # callback = NeuronCacheCallback(
+        #     tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
+        #     original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
+        #     fetch=fetch,
+        #     push=push,
+        #     wait_for_everyone_on_fetch=True,
+        #     wait_for_everyone_on_push=True,
+        # )
         # self.add_callback(callback)
 
         # Make the model Neuron-compatible for generation.
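Note on the `from_pretrained_for_mp` change above: the checkpoint download is now gated on `is_main_worker()` with an `xm.rendezvous` barrier in between, so a single worker downloads and converts the checkpoint while the others wait and then read the already-populated cache. Below is a minimal, self-contained sketch of that pattern, not the patch's actual implementation: `fetch_checkpoint_to_cache` is a hypothetical stand-in for `download_checkpoints_in_cache`, and checking `xm.get_ordinal() == 0` stands in for the `is_main_worker()` helper; `xm.rendezvous` and `xm.get_ordinal` are real torch_xla APIs.

    import torch_xla.core.xla_model as xm


    def fetch_checkpoint_to_cache(model_id: str) -> str:
        """Hypothetical helper: download the checkpoint (and convert it to
        safetensors) into a cache shared by all workers on the host, and
        return the local path. Must be idempotent: on a warm cache it only
        resolves local filenames."""
        raise NotImplementedError


    def load_checkpoint_cooperatively(model_id: str) -> str:
        # `xm.get_ordinal() == 0` plays the role of `is_main_worker()`.
        is_main = xm.get_ordinal() == 0

        if is_main:
            path = fetch_checkpoint_to_cache(model_id)

        # Every worker must reach the same rendezvous tag; non-main workers
        # block here until the main worker finishes downloading/converting.
        xm.rendezvous("waiting after download and conversion")

        if not is_main:
            # The cache is now warm, so this call only reads local files.
            path = fetch_checkpoint_to_cache(model_id)

        return path

Having the non-main workers repeat the (now cache-only) call, rather than broadcasting the result from the main worker, keeps the return values identical on every rank without any extra collectives.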