From c7a1f4bee4fbad95703f2055e7c7ee40f053194b Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 9 Apr 2024 16:47:15 +0200
Subject: [PATCH] [WIP] integrate new API for saving and loading

---
 optimum/neuron/accelerate/accelerator.py    |   8 +-
 .../neuron/accelerate/utils/dataclasses.py  |   2 +
 optimum/neuron/distributed/base.py          | 174 ++++--------------
 optimum/neuron/distributed/checkpointing.py |   8 +-
 optimum/neuron/distributed/utils.py         |   2 +-
 optimum/neuron/trainers.py                  |   8 +-
 optimum/neuron/training_args.py             |  17 ++
 tests/distributed/test_common.py            |   4 +-
 8 files changed, 80 insertions(+), 143 deletions(-)

diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py
index a25a23a26..6b61aa345 100644
--- a/optimum/neuron/accelerate/accelerator.py
+++ b/optimum/neuron/accelerate/accelerator.py
@@ -683,7 +683,13 @@ def save_state_for_mp(self, output_dir: Optional[str] = None, **save_model_func_
         def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
             logger.info("Saving parallel model and optimizer")
             parallelizer = ParallelizersManager.parallelizer_for_model(model)
-            parallelizer.save_model_checkpoint(model, output_dir, as_regular=False, optimizer=optimizer)
+            parallelizer.save_model_sharded_checkpoint(
+                model,
+                output_dir,
+                optimizer=optimizer,
+                use_xser=self.state.mp_plugin.use_xser,
+                async_save=self.state.mp_plugin.async_save,
+            )
             logger.info(f"Parallel model and optimizer saved to the directory {output_dir}")

         return self._custom_save_state(
diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py
index 325b7a088..a94c3b45f 100644
--- a/optimum/neuron/accelerate/utils/dataclasses.py
+++ b/optimum/neuron/accelerate/utils/dataclasses.py
@@ -160,6 +160,8 @@ class ModelParallelismPlugin:
     gradient_checkpointing: bool = False
     checkpoint_dir: Optional[Union[str, Path]] = None
     num_ranks_per_loading_step: int = -1
+    use_xser: bool = True
+    async_save: bool = False

     def __post_init__(self):
         if self.tensor_parallel_size < 1:
diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index d0c73ce4c..cebcb0440 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -17,7 +17,6 @@
 import contextlib
 import gc
 import math
-import shutil
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import asdict
@@ -27,7 +26,6 @@

 import torch
 from transformers import PreTrainedModel
-from transformers.utils import WEIGHTS_NAME

 from ...utils import logging
 from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -42,7 +40,7 @@
     SequenceCollectiveOpInfo,
 )
 from .utils import (
-    TENSOR_PARALLEL_SHARDS_DIR_NAME,
+    MODEL_PARALLEL_SHARDS_DIR_NAME,
     OptimumGQAQKVColumnParallelLinear,
     OptimumNeuronFXTracer,
     ParameterMetadata,
@@ -762,7 +760,7 @@ def should_parallelize_layer_predicate_func(layer):

         # TODO: can we optimize by skipping initialization and weight loading when `checkpoint_dir` is not None.
         if not is_precompilation() and checkpoint_dir is not None:
-            cls.load_model_checkpoint(model, checkpoint_dir)
+            cls.load_model_sharded_checkpoint(model, checkpoint_dir)

         model._gqa_qkv_metadata = gqa_qkv_metadata

@@ -919,162 +917,70 @@ def _get_parameters_tp_metadata(cls, named_parameters: Dict[str, "torch.nn.Param

     @classmethod
     @requires_neuronx_distributed
-    def save_model_checkpoint_as_regular(
-        cls,
-        model: "PreTrainedModel",
-        output_dir: Union[str, Path],
-        optimizer: Optional["torch.optim.Optimizer"] = None,
-    ):
-        import neuronx_distributed
-        import torch_xla.core.xla_model as xm
-        from neuronx_distributed.parallel_layers.parallel_state import (
-            get_data_parallel_rank,
-            get_tensor_model_parallel_rank,
-        )
-
-        cls._check_model_was_parallelized(model)
-
-        data_parallel_rank = get_data_parallel_rank()
-        tensor_parallel_rank = get_tensor_model_parallel_rank()
-
-        if data_parallel_rank != 0:
-            return
-
-        if not isinstance(output_dir, Path):
-            output_dir = Path(output_dir)
-
-        if is_main_worker() and optimizer is not None:
-            logger.warning(
-                "Saving the optimizer state as a regular file under the tensor parallel setting is not supported yet."
-            )
-
-        state_dict = {}
-        for name, param in model.named_parameters():
-            if getattr(param, "tensor_model_parallel", False):
-                if param.partition_dim == 1:
-                    tensor = neuronx_distributed.utils.gather_from_tensor_model_parallel_region(param)
-                else:
-                    # Because the gather works only on last dim. Need to make it work for all dims.
-                    tensor = neuronx_distributed.utils.gather_from_tensor_model_parallel_region(
-                        param.transpose()
-                    ).transpose()
-            else:
-                tensor = param
-            state_dict[name] = tensor
-
-        model_state_dict = {"model": state_dict}
-        should_save = tensor_parallel_rank == 0
-        xm._maybe_convert_to_cpu(model_state_dict, convert=should_save)
-        if should_save:
-            output_path = output_dir / WEIGHTS_NAME
-            torch.save(model_state_dict["model"], output_path.as_posix())
-        xm.rendezvous("saving regular checkpoint")
-
-    @classmethod
-    @requires_neuronx_distributed
-    def save_model_checkpoint_as_sharded(
+    def save_model_sharded_checkpoint(
         cls,
         model: Union["PreTrainedModel", "NxDPPModel"],
         output_dir: Union[str, Path],
         optimizer: Optional["torch.optim.Optimizer"] = None,
+        use_xser: bool = True,
+        async_save: bool = False,
     ):
-        import torch_xla.core.xla_model as xm
-        from neuronx_distributed import parallel_layers
-        from neuronx_distributed.pipeline import NxDPPModel
+        import neuronx_distributed

         cls._check_model_was_parallelized(model)

         if not isinstance(output_dir, Path):
             output_dir = Path(output_dir)

-        if isinstance(model, NxDPPModel):
-            model_state_dict = model.local_state_dict()
-        else:
-            model_state_dict = model.state_dict()
-
-        state_dict = {"model": model_state_dict}
-        state_dict["sharded_metadata"] = {
+        metadata = {}
+        metadata["sharded_metadata"] = {
             k: asdict(v) for k, v in cls._get_parameters_tp_metadata(dict(model.named_parameters())).items()
         }
-        state_dict["gqa_qkv_metadata"] = model._gqa_qkv_metadata
-
-        if optimizer is not None:
-            # TODO: have metadata working for the optimizer.
-            state_dict["optimizer_state_dict"] = optimizer.state_dict()
-
-        output_path = output_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
-
-        if is_main_worker():
-            if output_path.is_dir():
-                shutil.rmtree(output_path, ignore_errors=True)
-            output_path.mkdir()
-        xm.rendezvous("waiting before saving")
-        parallel_layers.save(state_dict, output_path.as_posix(), save_xser=True)
+        metadata["gqa_qkv_metadata"] = model._gqa_qkv_metadata
+
+        neuronx_distributed.trainer.save_checkpoint(
+            output_dir.as_posix(),
+            tag=MODEL_PARALLEL_SHARDS_DIR_NAME,
+            model=model,
+            optimizer=optimizer,
+            user_content=metadata,
+            use_xser=use_xser,
+            async_save=async_save,
+        )

     @classmethod
-    def save_model_checkpoint(
+    @requires_neuronx_distributed
+    def load_sharded_checkpoint(
         cls,
-        model: "PreTrainedModel",
-        output_dir: Union[str, Path],
-        as_regular: bool = False,
-        as_sharded: bool = True,
-        optimizer: Optional["torch.optim.Optimizer"] = None,
+        load_dir: Union[str, Path],
+        model: Optional["PreTrainedModel"] = None,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
-        if not as_regular and not as_sharded:
-            raise ValueError("At least as_regular or as_sharded must be True.")
-        if as_regular:
-            cls.save_model_checkpoint_as_regular(model, output_dir, optimizer=optimizer)
-        if as_sharded:
-            cls.save_model_checkpoint_as_sharded(model, output_dir, optimizer=optimizer)
-
-    @classmethod
-    @requires_neuronx_distributed
-    def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
         import neuronx_distributed

-        cls._check_model_was_parallelized(model)
+        if model is None and optimizer is None:
+            raise ValueError("At least a model or an optimizer must be provided")

-        if not isinstance(load_dir, Path):
-            load_dir = Path(load_dir)
-        neuronx_distributed.parallel_layers.load(
-            load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME,
-            model_or_optimizer=model,
-            load_xser=True,
-            sharded=True,
-        )
+        if model is not None:
+            cls._check_model_was_parallelized(model)

-    @classmethod
-    def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
         if not isinstance(load_dir, Path):
             load_dir = Path(load_dir)

-        if (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
-            cls.load_model_sharded_checkpoint(model, load_dir)
-        else:
+        if not (load_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
             raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.")

-    @classmethod
-    @requires_neuronx_distributed
-    def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
-        import neuronx_distributed
-        from neuronx_distributed.optimizer import NeuronZero1Optimizer
-
-        is_zero_1_optimizer = optimizer.__class__.__name__ == "NeuronAcceleratedOptimizer" and isinstance(
-            optimizer.optimizer, NeuronZero1Optimizer
+        neuronx_distributed.trainer.load_checkpoint(
+            load_dir.as_posix(),
+            tag=MODEL_PARALLEL_SHARDS_DIR_NAME,
+            model=model,
+            optimizer=optimizer,
         )
-        is_zero_1_optimizer = is_zero_1_optimizer or isinstance(optimizer, NeuronZero1Optimizer)
-        if is_zero_1_optimizer:
-            raise NotImplementedError(
-                "It is not possible to load a sharded optimizer checkpoint when using ZeRO-1 yet."
-            )
-        if not isinstance(load_dir, Path):
-            load_dir = Path(load_dir)
+    @classmethod
+    def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]):
+        return cls.load_sharded_checkpoint(load_dir, model=model)

-        neuronx_distributed.parallel_layers.load(
-            load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME,
-            model_or_optimizer=optimizer,
-            model_key="optimizer_state_dict",
-            load_xser=True,
-            sharded=True,
-        )
+    @classmethod
+    def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]):
+        return cls.load_sharded_checkpoint(load_dir, optimizer=optimizer)
diff --git a/optimum/neuron/distributed/checkpointing.py b/optimum/neuron/distributed/checkpointing.py
index 97a6128e3..fddf7604d 100644
--- a/optimum/neuron/distributed/checkpointing.py
+++ b/optimum/neuron/distributed/checkpointing.py
@@ -23,7 +23,7 @@
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME

 from ..utils.require_utils import requires_neuronx_distributed, requires_safetensors
-from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank
+from .utils import MODEL_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, compute_query_indices_for_rank


 def create_gqa_query_or_output_projection_weight_from_full_weight(
@@ -136,9 +136,9 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Union[str, Path]) ->
     if not isinstance(checkpoint_dir, Path):
         checkpoint_dir = Path(checkpoint_dir)

-    if checkpoint_dir.name != TENSOR_PARALLEL_SHARDS_DIR_NAME:
-        if (checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir():
-            checkpoint_dir = checkpoint_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME
+    if checkpoint_dir.name != MODEL_PARALLEL_SHARDS_DIR_NAME:
+        if (checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME).is_dir():
+            checkpoint_dir = checkpoint_dir / MODEL_PARALLEL_SHARDS_DIR_NAME
         else:
             raise ValueError(f"Could not find the tensor parallel shards from {checkpoint_dir}")

diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index cfef542d9..d87d14d43 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -63,7 +63,7 @@ def __init__(self, *args, **kwargs):

 logger = logging.get_logger()

-TENSOR_PARALLEL_SHARDS_DIR_NAME = "tensor_parallel_shards"
+MODEL_PARALLEL_SHARDS_DIR_NAME = "shards"


 @deprecate(
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 014e229ad..3388c4de8 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -488,7 +488,13 @@ def _save_xla(self, output_dir: Optional[str] = None):

             # This mark_step is needed to avoid hang issues.
             xm.mark_step()
-            Parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer)
+            Parallelizer.save_model_sharded_checkpoint(
+                self.model,
+                output_dir,
+                optimizer=self.optimizer,
+                use_xser=self.accelerator.state.mp_plugin.use_xser,
+                async_save=self.accelerator.state.mp_plugin.async_save,
+            )
         else:
             safe_save_function_patcher = Patcher(
                 [("transformers.modeling_utils.safe_save_file", torch_xla_safe_save_file)]
diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py
index 051b8289c..f77065e3c 100644
--- a/optimum/neuron/training_args.py
+++ b/optimum/neuron/training_args.py
@@ -114,6 +114,21 @@ class NeuronTrainingArgumentsMixin:
             )
         },
     )
+    use_xser: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to use `torch-xla` serialization when saving checkpoints when doing model parallelism"
+        },
+    )
+    async_save: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to use asynchronous saving method when doing model parallelism. It can boost saving "
+                "performance but will result in more host memory usage, increasing the risk of going OOM."
+            )
+        },
+    )

     def __post_init__(self):
         # Patches accelerate.utils.imports.is_tpu_available to match `is_torch_xla_available`
@@ -165,6 +180,8 @@ def __post_init__(self):
             gradient_checkpointing=self.gradient_checkpointing,
             checkpoint_dir=resume_from_checkpoint,
             num_ranks_per_loading_step=self.num_ranks_per_loading_step,
+            use_xser=self.use_xser,
+            async_save=self.async_save,
         )

         if self.bf16 and self.half_precision_backend == "amp":
diff --git a/tests/distributed/test_common.py b/tests/distributed/test_common.py
index aa9c44982..9c6ce7a8c 100644
--- a/tests/distributed/test_common.py
+++ b/tests/distributed/test_common.py
@@ -26,7 +26,7 @@
 from optimum.neuron.accelerate.utils.dataclasses import NeuronDistributedType
 from optimum.neuron.distributed.checkpointing import consolidate_model_parallel_checkpoints_to_unified_checkpoint
 from optimum.neuron.distributed.utils import (
-    TENSOR_PARALLEL_SHARDS_DIR_NAME,
+    MODEL_PARALLEL_SHARDS_DIR_NAME,
     make_optimizer_constructor_lazy,
 )
 from optimum.neuron.utils.import_utils import (
@@ -364,7 +364,7 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch):
             tensors_directory = f"{ref_data_file_name}.tensors"
             assert not pytorch_checkpoint_exists
             assert not safetensors_checkpoint_exists
-            assert TENSOR_PARALLEL_SHARDS_DIR_NAME in tmpdir_content
+            assert MODEL_PARALLEL_SHARDS_DIR_NAME in tmpdir_content
            assert ref_data_file_name in tmpdir_content
             assert tensors_directory in tmpdir_content
         else:
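
Below is a minimal usage sketch of the checkpointing API this patch wires in, for reviewers who want to see the call flow outside of the Trainer and Accelerator. It is not part of the patch. It assumes a Trainium environment where neuronx_distributed is available, a model that has already been parallelized, and that ParallelizersManager is importable from optimum.neuron.distributed (as it is used in accelerator.py); save_and_reload_sharded, model, optimizer, and output_dir are placeholders.

from pathlib import Path

from optimum.neuron.distributed import ParallelizersManager


def save_and_reload_sharded(model, optimizer, output_dir: Path) -> None:
    # Resolve the parallelizer for this model, as save_state_for_mp does in accelerator.py.
    parallelizer = ParallelizersManager.parallelizer_for_model(model)

    # Writes model and optimizer shards under output_dir / MODEL_PARALLEL_SHARDS_DIR_NAME ("shards")
    # via neuronx_distributed.trainer.save_checkpoint, with the sharded/GQA metadata as user_content.
    parallelizer.save_model_sharded_checkpoint(
        model,
        output_dir,
        optimizer=optimizer,
        use_xser=True,     # torch-xla serialization, the default
        async_save=False,  # True trades extra host memory for faster saving
    )

    # Both helpers delegate to load_sharded_checkpoint, which raises FileNotFoundError
    # if output_dir / "shards" does not exist.
    parallelizer.load_model_sharded_checkpoint(model, output_dir)
    parallelizer.load_optimizer_sharded_checkpoint(optimizer, output_dir)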