From 5ac76e67b6148a003b89ab7dbdbda3ee7a1b67dd Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 26 Feb 2024 18:58:05 +0100
Subject: [PATCH] Fix

Allow the Neuron cache synchronization to target an explicit compilation
cache directory (`--cache_dir` on the CLI, `cache_url` in
`synchronize_hub_cache`), initialize the model parallel state from
`NeuronAcceleratorState` whenever `torch.distributed` is initialized, and
make `get_model_param_count` and `is_main_worker_for_metrics` usable
without an initialized `torch.distributed` process group.

---
 optimum/commands/neuron/cache.py               | 16 +++++++++++++-
 optimum/neuron/accelerate/state.py             | 21 +++++++------------
 optimum/neuron/distributed/base.py             |  4 +++-
 optimum/neuron/trainers.py                     | 13 ++++++++++--
 optimum/neuron/utils/hub_neuronx_cache.py      |  8 ++++---
 optimum/neuron/utils/training_utils.py         | 16 +++++++++++---
 .../distributed/test_model_parallelization.py  |  3 +++
 7 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py
index 418226275..762602566 100644
--- a/optimum/commands/neuron/cache.py
+++ b/optimum/commands/neuron/cache.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """Defines the command line related to dealing with the Neuron cache repo."""
 
+from pathlib import Path
 from typing import TYPE_CHECKING
 
 from ...neuron.utils import get_hub_cached_entries, synchronize_hub_cache
@@ -23,6 +24,7 @@
     create_custom_cache_repo,
     set_custom_cache_repo_name_in_hf_home,
 )
+from ...neuron.utils.require_utils import requires_torch_neuronx
 from ...neuron.utils.runner import ExampleRunner
 from ...utils import logging
 from ..base import BaseOptimumCLICommand, CommandInfo
@@ -165,9 +167,21 @@ class SynchronizeRepoCommand(BaseOptimumCLICommand):
     @staticmethod
     def parse_args(parser: "ArgumentParser"):
         parser.add_argument("--repo_id", type=str, default=None, help="The name of the repo to use as remote cache.")
+        parser.add_argument(
+            "--cache_dir", type=str, default=None, help="The cache directory that contains the compilation files."
+        )
 
+    @requires_torch_neuronx
     def run(self):
-        synchronize_hub_cache(self.args.repo_id)
+        from libneuronxla.neuron_cc_cache import CacheUrl
+
+        if self.args.cache_dir is not None:
+            if not Path(self.args.cache_dir).is_dir():
+                raise ValueError(f"The {self.args.cache_dir} directory does not exist.")
+            cache_url = CacheUrl(self.args.cache_dir, url_type="fs")
+        else:
+            cache_url = None
+        synchronize_hub_cache(cache_url=cache_url, cache_repo_id=self.args.repo_id)
 
 
 class LookupRepoCommand(BaseOptimumCLICommand):
diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py
index 1b1fe8c6e..6ba710445 100644
--- a/optimum/neuron/accelerate/state.py
+++ b/optimum/neuron/accelerate/state.py
@@ -267,23 +267,11 @@ def __init__(
             os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_TP", "false") == "true"
             or os.environ.get("ACCELERATE_USE_NEURONX_DISTRIBUTED_PP", "false") == "true"
         ):
-            if not is_neuronx_distributed_available():
-                raise RuntimeError(
-                    "Model parallelism requires the neuronx_distributed package. You can install it by "
-                    "running: python -m pip install neuronx_distributed --extra-index-url "
-                    "https://pip.repos.neuron.amazonaws.com"
-                )
             if mp_plugin is None:
                 raise ValueError(
-                    "Could not initialize `neuronx_distributed` model parallelism because no "
-                    "`ModelParallelismPlugin` was provided."
+                    "Could not initialize model parallelism because no `ModelParallelismPlugin` was provided."
                 )
             if mp_plugin.should_parallelize:
-                if not parallel_state.model_parallel_is_initialized():
-                    parallel_state.initialize_model_parallel(
-                        tensor_model_parallel_size=mp_plugin.tensor_parallel_size,
-                        pipeline_model_parallel_size=mp_plugin.pipeline_parallel_size,
-                    )
                 self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM
             else:
                 logger.warning(
@@ -293,6 +281,13 @@ def __init__(
             self.mp_plugin = mp_plugin
         else:
             self.mp_plugin = ModelParallelismPlugin()
+
+        if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
+            parallel_state.initialize_model_parallel(
+                tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
+                pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
+            )
+
         if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true":
             self.distributed_type = NeuronDistributedType.XLA_FSDP
             if self._mixed_precision != "no":
diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index 914f47c5e..13cb99ab6 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -487,6 +487,7 @@ def parallelize(
         from neuronx_distributed.pipeline import NxDPPModel
 
         tp_size = get_tensor_model_parallel_size()
+        pp_size = get_pipeline_model_parallel_size()
 
         sequence_parallel_enabled = sequence_parallel_enabled and tp_size > 1
 
@@ -501,6 +502,8 @@ def parallelize(
             parameter_to_name = {p: n for n, p in name_to_parameter.items()}
 
             def should_parallelize_layer_predicate_func(layer):
+                if pp_size == 1:
+                    return True
                 for p in layer.parameters():
                     if p not in parameter_to_name:
                         return True
@@ -558,7 +561,6 @@ def should_parallelize_layer_predicate_func(layer):
         if is_main_worker():
             logger.info("Load and initialization of the weights done.")
 
-        pp_size = get_pipeline_model_parallel_size()
         if pp_size > 1:
             if not cls.supports_pipeline_parallelism():
                 raise NotImplementedError("{cls} does not support pipeline parallelism.")
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 64250c5d8..525e54519 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -75,6 +75,7 @@
 from .utils.cache_utils import (
     get_hf_hub_cache_repos,
     get_model_name_or_path,
+    get_neuron_cache_path,
     get_neuronxcc_version,
     get_num_neuron_cores_used,
     has_write_access_to_repo,
@@ -82,7 +83,7 @@
 from .utils.hub_neuronx_cache import ModelCacheEntry, hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
 from .utils.misc import is_main_worker
 from .utils.patching import patch_everywhere
-from .utils.require_utils import requires_neuronx_distributed
+from .utils.require_utils import requires_neuronx_distributed, requires_torch_neuronx
 from .utils.training_utils import (
     TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
     get_model_param_count,
@@ -269,12 +270,20 @@ def create_accelerator_and_postprocess(self):
             ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
             ds_plugin.hf_ds_config.trainer_config_process(self.args)
 
+    @requires_torch_neuronx
     def synchronize_hub_cache(self):
+        from libneuronxla.neuron_cc_cache import CacheUrl
+
         repo_id = get_hf_hub_cache_repos()[0]
         if xm.get_ordinal() == 0:
             has_write_access = has_write_access_to_repo(repo_id)
             if has_write_access:
-                synchronize_hub_cache(repo_id)
+                cache_path = get_neuron_cache_path()
+                if cache_path is not None:
+                    cache_url = CacheUrl(cache_path.as_posix(), url_type="fs")
+                else:
+                    cache_url = None
+                synchronize_hub_cache(cache_url=cache_url, cache_repo_id=repo_id)
         xm.rendezvous("Hub cache synchronization done")
 
     def _wrap_model(self, model, training=True, dataloader=None):
diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py
index d521ffee7..e68f6f11b 100644
--- a/optimum/neuron/utils/hub_neuronx_cache.py
+++ b/optimum/neuron/utils/hub_neuronx_cache.py
@@ -332,14 +332,16 @@ def patch_neuron_cc_wrapper():
 
 
 @requires_torch_neuronx
-def synchronize_hub_cache(cache_repo_id: Optional[str] = None):
+def synchronize_hub_cache(cache_url: Optional[CacheUrl] = None, cache_repo_id: Optional[str] = None):
     """Synchronize the neuronx compiler cache with the optimum-neuron hub cache.
 
     Args:
-        repo_id (`Optional[str]`, default to None):
+        cache_url (`Optional[CacheUrl]`, defaults to `None`):
+            The cache url to use for synchronization.
+        cache_repo_id (`Optional[str]`, defaults to `None`):
             The id of the HuggingFace cache repository, in the form 'org|user/name'.
     """
-    hub_cache_proxy = _create_hub_compile_cache_proxy(cache_repo_id=cache_repo_id)
+    hub_cache_proxy = _create_hub_compile_cache_proxy(cache_url=cache_url, cache_repo_id=cache_repo_id)
     hub_cache_proxy.synchronize()
 
 
diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index 47e573f2a..21110bf61 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -405,20 +405,27 @@ def get_model_param_count(model: Union[torch.nn.Module, "NxDPPModel"], trainable
         named_parameters = model.named_parameters()
         shared_parameters_across_pipeline_stages = {}
 
-    pp_rank = get_pipeline_model_parallel_rank()
+    if torch.distributed.is_initialized():
+        tp_size = get_tensor_model_parallel_size()
+        pp_size = get_pipeline_model_parallel_size()
+        pp_rank = get_pipeline_model_parallel_rank()
+    else:
+        tp_size = 1
+        pp_size = 1
+        pp_rank = 0
 
     def numel(parameter_name, parameter) -> int:
         should_count_param = shared_parameters_across_pipeline_stages.get(parameter_name, pp_rank) == pp_rank
 
         num_elements = parameter.numel()
         if getattr(parameter, "tensor_model_parallel", False):
-            num_elements *= get_tensor_model_parallel_size()
+            num_elements *= tp_size
 
         return num_elements if should_count_param else 0
 
     param_count = sum(numel(n, p) for n, p in named_parameters if not trainable_only or p.requires_grad)
 
-    if get_pipeline_model_parallel_size() > 1:
+    if pp_size > 1:
         param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device())
         param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True))
         param_count = int(param_count.detach().item())
@@ -436,6 +443,9 @@ def is_main_worker_for_metrics() -> bool:
         get_tensor_model_parallel_rank,
     )
 
+    if not torch.distributed.is_initialized():
+        return True
+
     dp_rank = get_data_parallel_rank()
     tp_rank = get_tensor_model_parallel_rank()
     pp_rank = get_pipeline_model_parallel_rank()
diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py
index a7097dc4c..4bb9abaaf 100644
--- a/tests/distributed/test_model_parallelization.py
+++ b/tests/distributed/test_model_parallelization.py
@@ -385,6 +385,9 @@ def _parallel_model_matches_original_model(
     model = accelerator.patch_model_for_neuron(model)
     with torch.no_grad():
         if pp_size == 1:
+            # This is set to False by `accelerator.prepare`, which we want in the general case, but here let's
+            # enable the cache to test that the KV cache matches the original model.
+            model.config.use_cache = True
             model = model.eval()
             model_outputs = model(**xla_inputs)
         else:
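-- 
Usage sketch (the `optimum-cli neuron cache synchronize` subcommand name and
the local cache path below are assumptions for illustration, not taken from
this patch):

    # CLI: push a specific local compilation cache to the Hub cache repo.
    #   optimum-cli neuron cache synchronize --repo_id my-org/neuron-cache --cache_dir /var/tmp/neuron-compile-cache

    # Python equivalent of what SynchronizeRepoCommand.run() does with the new
    # signature (requires a Neuron environment providing torch_neuronx and libneuronxla):
    from libneuronxla.neuron_cc_cache import CacheUrl

    from optimum.neuron.utils import synchronize_hub_cache

    cache_url = CacheUrl("/var/tmp/neuron-compile-cache", url_type="fs")  # assumed local cache path
    synchronize_hub_cache(cache_url=cache_url, cache_repo_id="my-org/neuron-cache")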