diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index bc0f31772..17926b240 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -343,9 +343,13 @@ def prepare_model_for_xla_fsdp( def _prepare_model_for_tp( self, model: torch.nn.Module, device_placement: Optional[bool] = None, evaluation_mode: bool = False ): + if model in self._models or Parallelizer.was_parallelized(model): + return model + cpu_ids = [id(v) for v in model.parameters()] # TODO: enable self.device (if needed). model = self.state.tp_plugin.parallelize_model(model, device=None) + if os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1": model.to(torch.bfloat16) else: diff --git a/optimum/neuron/accelerate/utils/dataclasses.py b/optimum/neuron/accelerate/utils/dataclasses.py index 7cf5e04fa..d5ade238a 100644 --- a/optimum/neuron/accelerate/utils/dataclasses.py +++ b/optimum/neuron/accelerate/utils/dataclasses.py @@ -17,6 +17,7 @@ import enum import os from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union import torch @@ -143,10 +144,13 @@ class TensorParallelismPlugin: tensor_parallel_size: int = 1 parallelize_embeddings: bool = True sequence_parallel_enabled: bool = False + checkpoint_dir: Optional[Union[str, Path]] = None def __post_init__(self): if self.tensor_parallel_size < 1: raise ValueError(f"The tensor parallel size must be >= 1, but {self.tensor_parallel_size} was given here.") + if isinstance(self.checkpoint_dir, str): + self.checkpoint_dir = Path(self.checkpoint_dir) @property def should_parallelize(self): @@ -163,5 +167,6 @@ def parallelize_model( device=device, parallelize_embeddings=self.parallelize_embeddings, sequence_parallel_enabled=self.sequence_parallel_enabled, + checkpoint_dir=self.checkpoint_dir, ) return parallelized_model diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index e2ac71fff..250aa2461 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -15,6 +15,7 @@ """Base class related to `neuronx_distributed` to perform parallelism.""" import contextlib +import gc import shutil from abc import ABC, abstractclassmethod from dataclasses import asdict @@ -28,6 +29,7 @@ from ...utils import logging from ..utils import is_neuronx_distributed_available, is_torch_xla_available from ..utils.deprecate_utils import deprecate +from ..utils.require_utils import requires_neuronx_distributed from .parallel_layers import ( IOSequenceParallelizer, LayerNormSequenceParallelizer, @@ -37,14 +39,6 @@ from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, WeightInformation, load_tensor_for_weight -if is_neuronx_distributed_available(): - import neuronx_distributed - from neuronx_distributed import parallel_layers - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - if TYPE_CHECKING: from transformers import PreTrainedModel @@ -164,12 +158,14 @@ def patch_for_sequence_parallelism(cls, model: "PreTrainedModel", sequence_paral ) @classmethod + @requires_neuronx_distributed def parallelize( cls, model: "PreTrainedModel", device: Optional["torch.device"] = None, parallelize_embeddings: bool = True, sequence_parallel_enabled: bool = False, + checkpoint_dir: Optional[Union[str, Path]] = None, ) -> "PreTrainedModel": """ Parallelizes the model by transforming regular layer into their parallel counterparts using @@ 
-188,6 +184,9 @@ def parallelize( This can be disabled in the case when the TP size does not divide the vocabulary size. sequence_parallel_enabled (`bool`, defaults to `False`): Whether or not sequence parallelism is enabled. + checkpoint_dir (`Optional[Union[str, Path]]`): + Path to a sharded checkpoint. If specified, the checkpoint weights will be loaded to the parallelized + model. Returns: `PreTrainedModel`: The parallelized model. @@ -195,6 +194,8 @@ def parallelize( if sequence_parallel_enabled and cls.SEQUENCE_PARALLEL_LAYERNORM_PATTERNS is None: raise NotImplementedError(f"Sequence parallelism is not supported for {model.__class__}.") + from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank + # Preparing the model for sequence parallelism: # 1. Transforming the LayerNorms. layer_norm_qualified_name_patterns = ( @@ -259,7 +260,7 @@ def parallelize( # parallelization since those are the only classes that we initialize on the `meta` device. num_dims = current_weight.dim() partition_dim = getattr(current_weight, "partition_dim") - tp_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank() + tp_rank = get_tensor_model_parallel_rank() size_per_rank = current_weight.size(partition_dim) slices = [ None @@ -298,6 +299,9 @@ def parallelize( # `reset_parameters()` method. mod.reset_parameters() + if checkpoint_dir is not None: + cls.load_model_checkpoint(model, checkpoint_dir) + return model @classmethod @@ -305,7 +309,10 @@ def deparallelize(cls, model: "PreTrainedModel") -> "PreTrainedModel": raise NotImplementedError @classmethod + @requires_neuronx_distributed def was_parallelized(cls, model: "PreTrainedModel") -> bool: + from neuronx_distributed import parallel_layers + parallel_layer_classes = ( parallel_layers.ParallelEmbedding, parallel_layers.ColumnParallelLinear, @@ -410,6 +417,7 @@ def _get_parameters_tp_metadata(cls, named_parameters: Dict[str, "torch.nn.Param return tp_metadata @classmethod + @requires_neuronx_distributed def save_model_checkpoint_as_regular( cls, model: "PreTrainedModel", @@ -417,8 +425,16 @@ def save_model_checkpoint_as_regular( optimizer: Optional["torch.optim.Optimizer"] = None, ): cls._check_model_was_parallelized(model) - data_parallel_rank = parallel_layers.parallel_state.get_data_parallel_rank() - tensor_parallel_rank = parallel_layers.parallel_state.get_tensor_model_parallel_rank() + + import neuronx_distributed + import torch_xla.core.xla_model as xm + from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_rank, + get_tensor_model_parallel_rank, + ) + + data_parallel_rank = get_data_parallel_rank() + tensor_parallel_rank = get_tensor_model_parallel_rank() if data_parallel_rank != 0: return @@ -454,6 +470,7 @@ def save_model_checkpoint_as_regular( xm.rendezvous("saving regular checkpoint") @classmethod + @requires_neuronx_distributed def save_model_checkpoint_as_sharded( cls, model: "PreTrainedModel", @@ -462,6 +479,8 @@ def save_model_checkpoint_as_sharded( ): cls._check_model_was_parallelized(model) + import torch_xla.core.xla_model as xm + from neuronx_distributed import parallel_layers from neuronx_distributed.parallel_layers.parallel_state import ( get_data_parallel_rank, get_tensor_model_parallel_rank, @@ -508,12 +527,11 @@ def save_model_checkpoint( cls.save_model_checkpoint_as_sharded(model, output_dir, optimizer=optimizer) @classmethod - def load_model_regular_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]): - raise NotImplementedError("This 
requires being able to deparallelize the model.") - - @classmethod + @requires_neuronx_distributed def load_model_sharded_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Path]): cls._check_model_was_parallelized(model) + from neuronx_distributed import parallel_layers + if not isinstance(load_dir, Path): load_dir = Path(load_dir) parallel_layers.load(load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME, model=model, sharded=True) @@ -525,7 +543,54 @@ def load_model_checkpoint(cls, model: "PreTrainedModel", load_dir: Union[str, Pa if (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir(): cls.load_model_sharded_checkpoint(model, load_dir) - elif (load_dir / WEIGHTS_NAME).is_file(): - cls.load_model_regular_checkpoint(model, load_dir) else: - raise FileNotFoundError(f"Could not find a checkpoint file under {load_dir.as_posix()}.") + raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.") + + @classmethod + @requires_neuronx_distributed + def load_optimizer_sharded_checkpoint(cls, optimizer: "torch.optim.Optimizer", load_dir: Union[str, Path]): + from neuronx_distributed.optimizer import NeuronZero1Optimizer + + is_zero_1_optimizer = optimizer.__class__.__name__ == "NeuronAcceleratedOptimizer" and isinstance( + optimizer.optimizer, NeuronZero1Optimizer + ) + is_zero_1_optimizer = is_zero_1_optimizer or isinstance(optimizer, NeuronZero1Optimizer) + if is_zero_1_optimizer: + raise NotImplementedError( + "It is not possible to load a sharded optimizer checkpoint when using ZeRO-1 yet." + ) + + if not isinstance(load_dir, Path): + load_dir = Path(load_dir) + + import torch_xla.core.xla_model as xm + from neuronx_distributed.parallel_layers.parallel_state import ( + get_pipeline_model_parallel_rank, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_size, + ) + + world_size = get_tensor_model_parallel_size() + tp_rank = get_tensor_model_parallel_rank() + pp_rank = get_pipeline_model_parallel_rank() + + if not (load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME).is_dir(): + raise FileNotFoundError(f"Could not find a sharded checkpoint directory under {load_dir.as_posix()}.") + + checkpoint_name = load_dir / TENSOR_PARALLEL_SHARDS_DIR_NAME / f"tp_rank_{tp_rank:02d}_pp_rank{pp_rank:02d}.pt" + + device = "xla" + for group in optimizer.param_groups: + for p in group["params"]: + device = p.device + break + + for worker_start in range(0, world_size): + if tp_rank == worker_start: + checkpoint = torch.load(checkpoint_name, map_location="cpu") + optimizer_state_dict = checkpoint["optimizer_state_dict"] + xm.send_cpu_data_to_device(optimizer_state_dict, device) + optimizer.load_state_dict(optimizer_state_dict) + del checkpoint + gc.collect() + xm.rendezvous("neuron.load_checkpoint" + str(worker_start)) diff --git a/optimum/neuron/generation/utils.py b/optimum/neuron/generation/utils.py index 8a24a5c80..ce6f93e8b 100644 --- a/optimum/neuron/generation/utils.py +++ b/optimum/neuron/generation/utils.py @@ -821,9 +821,7 @@ def generate( model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states if generation_config.use_cache: - warnings.warn( - "use_cache is not supported for generation on Neuron devices, switching to use_cache=False." 
- ) + warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.") model_kwargs["use_cache"] = False accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 44f35ceb0..72047d479 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -29,6 +29,7 @@ from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments from transformers.dependency_versions_check import dep_version_check from transformers.integrations import is_fairscale_available +from transformers.modeling_utils import unwrap_model from transformers.trainer import ( OPTIMIZER_NAME, SCHEDULER_NAME, @@ -42,7 +43,7 @@ PREFIX_CHECKPOINT_DIR, EvalLoopOutput, ) -from transformers.utils import is_sagemaker_mp_enabled +from transformers.utils import WEIGHTS_NAME, is_sagemaker_mp_enabled from ..utils import check_if_transformers_greater, logging from .accelerate import NeuronAccelerator, NeuronDistributedType @@ -372,12 +373,61 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) - def _save_checkpoint_with_accelerator(self, model, trial, metrics=None): - if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP and not self.is_fsdp_enabled: - # TODO: handle this case better? - # Do we want to fail here? Can we save anyway? - raise RuntimeError("Cannot save checkpoint if FSDP is not enabled.") + def _save_xla(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + xm.rendezvous("saving_checkpoint") + if self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM: + logger.info("Model parallelism is enabled, only saving the model sharded state dict.") + if isinstance(self.model, PreTrainedModel): + self.model.config.save_pretrained(output_dir) + + parallelizer = ParallelizersManager.parallelizer_for_model(self.model) + # This mark_step is needed to avoid hang issues. 
+ xm.mark_step() + parallelizer.save_model_checkpoint(self.model, output_dir, as_sharded=True, optimizer=self.optimizer) + else: + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=self.model.state_dict(), + save_function=xm.save, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + + if self.tokenizer is not None and self.args.should_save: + self.tokenizer.save_pretrained(output_dir) + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + if output_dir is None: + output_dir = self.args.output_dir + + self._save_xla(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" if self.hp_search_backend is None and trial is None: @@ -385,25 +435,20 @@ def _save_checkpoint_with_accelerator(self, model, trial, metrics=None): run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) - os.makedirs(output_dir, exist_ok=True) if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) - torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - if self.tokenizer is not None and self.args.should_save: - self.tokenizer.save_pretrained(output_dir) - - if isinstance(self.model, PreTrainedModel): - self.model.config.save_pretrained(output_dir) + self.save_model(output_dir, _internal_call=True) - self.accelerator.save_state(output_dir) + # The optimizer state is saved in the shard alongside with the model parameters when doing TP. + if self.accelerator.distributed_type is not NeuronDistributedType.TENSOR_PARALLELISM: + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) - # Save scaler - # TODO: is grad scaling supported with TORCH XLA? - # reissue_pt_warnings(caught_warnings) - # if self.do_grad_scaling: - # xm.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) # Determine the new best metric / best model checkpoint if metrics is not None and self.args.metric_for_best_model is not None: @@ -432,8 +477,7 @@ def _save_checkpoint_with_accelerator(self, model, trial, metrics=None): "cpu": torch.random.get_rng_state(), } - if is_torch_xla_available(): - rng_states["xla"] = xm.get_rng_state() + rng_states["xla"] = xm.get_rng_state() # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may # not yet exist. 
@@ -451,28 +495,13 @@ def _save_checkpoint_with_accelerator(self, model, trial, metrics=None): if self.args.should_save: self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) - def _save_checkpoint(self, model, trial, metrics=None): - if check_if_transformers_greater("4.30.0") and self.accelerator.distributed_type in [ - NeuronDistributedType.XLA_FSDP, - NeuronDistributedType.TENSOR_PARALLELISM, - ]: - return self._save_checkpoint_with_accelerator(model, trial, metrics=metrics) - return super()._save_checkpoint(model, trial, metrics=metrics) - - def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): - if output_dir is None: - output_dir = self.args.output_dir - if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP: - self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir, 0) - elif self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM: - parallelizer = ParallelizersManager.parallelizer_for_model(self.model) - parallelizer.save_model_checkpoint(self.model, output_dir, as_regular=False) - else: - return super().save_model(output_dir=output_dir, _internal_call=_internal_call) + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + # It has been handled during model parallelization. + if self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM: + return + super()._load_from_checkpoint(resume_from_checkpoint, model=model) def _load_optimizer_and_scheduler_for_xla_fsdp(self, checkpoint): - if checkpoint is None: - return checkpoint_file_exists = ( glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") if is_sagemaker_mp_enabled() @@ -490,9 +519,19 @@ def _load_optimizer_and_scheduler_for_xla_fsdp(self, checkpoint): # TODO: load grad scaling? def _load_optimizer_and_scheduler(self, checkpoint): - if self.fsdp or self.is_fsdp_enabled: + if checkpoint is None: + return + if self.accelerator.distributed_type is NeuronDistributedType.XLA_FSDP: return self._load_optimizer_and_scheduler_for_xla_fsdp(checkpoint) - return super()._load_optimizer_and_scheduler(checkpoint) + elif self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM: + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + + parallelizer = ParallelizersManager.parallelizer_for_model(self.model) + parallelizer.load_optimizer_sharded_checkpoint(self.optimizer, checkpoint) + else: + return super()._load_optimizer_and_scheduler(checkpoint) @patch_within_function(("transformers.trainer.skip_first_batches", skip_first_batches)) def _inner_training_loop( diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py index 911858cc1..f9d8d2dfc 100644 --- a/optimum/neuron/training_args.py +++ b/optimum/neuron/training_args.py @@ -24,6 +24,7 @@ import torch from accelerate.utils import DistributedType from packaging import version +from transformers.trainer_utils import get_last_checkpoint from transformers.training_args import ParallelMode, TrainingArguments from transformers.training_args_seq2seq import Seq2SeqTrainingArguments from transformers.utils import ( @@ -97,7 +98,19 @@ def __post_init__(self): "The minimal required Transformers version to perform XLA FSDP is " f"{TRANSFORMERS_MIN_VERSION_FOR_XLA_FSDP} but {transformers.__version__} is installed."
) - self.tp_plugin = TensorParallelismPlugin(self.tensor_parallel_size, not self.disable_embedding_parallelization) + + resume_from_checkpoint = self.resume_from_checkpoint + if resume_from_checkpoint is None and os.path.isdir(self.output_dir): + # If checkpoint is None, then there was no checkpoint in output dir, otherwise we use it. + checkpoint = get_last_checkpoint(self.output_dir) + resume_from_checkpoint = checkpoint + + self.tp_plugin = TensorParallelismPlugin( + self.tensor_parallel_size, + not self.disable_embedding_parallelization, + sequence_parallel_enabled=self.sequence_parallel_enabled, + checkpoint_dir=resume_from_checkpoint, + ) super().__post_init__() # Needed only to specialize the warning message for FSDP. diff --git a/optimum/neuron/utils/runner.py b/optimum/neuron/utils/runner.py index 3b163fd6f..d0c262056 100644 --- a/optimum/neuron/utils/runner.py +++ b/optimum/neuron/utils/runner.py @@ -185,7 +185,7 @@ def __init__( task: str, example_dir: Optional[Union[str, Path]] = None, config_overrides: Optional[Dict[str, Any]] = None, - use_venv: bool = True, + use_venv: bool = False, install_requirements: bool = True, ): self.model_name_or_path = model_name_or_path @@ -383,6 +383,7 @@ def run( max_eval_samples: Optional[int] = None, logging_steps: int = 1, save_steps: int = -1, + save_total_limit: int = -1, learning_rate: float = 1e-4, tensor_parallel_size: int = 1, disable_embedding_parallelization: bool = False, @@ -390,6 +391,7 @@ def run( output_dir: Optional[Union[Path, str]] = None, do_precompilation: bool = False, print_outputs: bool = False, + resume_from_checkpoint: Optional[Union[str, Path]] = None, _disable_is_private_model_repo_check: bool = False, ) -> Tuple[int, str]: if num_cores <= 0 or num_cores > 32: @@ -453,7 +455,7 @@ def compute_max_train_samples( if do_eval: cmd.append("--do_eval") if max_eval_samples is not None: - cmd.append("--max_eval_samples {max_eval_samples}") + cmd.append(f"--max_eval_samples {max_eval_samples}") cmd.append(f"--learning_rate {learning_rate}") cmd.append(f"--per_device_train_batch_size {train_batch_size}") if do_eval: @@ -468,7 +470,7 @@ def compute_max_train_samples( cmd.append(f"--logging_steps {logging_steps}") cmd.append("--save_strategy steps") cmd.append(f"--save_steps {save_steps}") - cmd.append("--save_total_limit 1") + cmd.append(f"--save_total_limit {save_total_limit}") # Parallelism if tensor_parallel_size > 1: @@ -518,6 +520,9 @@ def split_args_and_value_in_command(cmd: List[str]) -> List[str]: else: cmd.append(f"--output_dir {output_dir}") + if resume_from_checkpoint is not None: + cmd.append(f"--resume_from_checkpoint {resume_from_checkpoint}") + env = dict(os.environ) if _disable_is_private_model_repo_check: env["OPTIMUM_NEURON_DISABLE_IS_PRIVATE_REPO_CHECK"] = "true" diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index daf7cf817..55031438d 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -286,7 +286,7 @@ def skip_first_batches(dataloader, num_batches=0): """ import torch_xla.distributed.parallel_loader as pl - if isinstance(dataloader, (pl.ParallelLoader, pl.PerDeviceLoader)): + if isinstance(dataloader, (pl.ParallelLoader, pl.PerDeviceLoader, pl.MpDeviceLoader)): dataloader._loader = skip_first_batches(dataloader._loader, num_batches=num_batches) else: dataloader = accelerate_skip_first_batches(dataloader, num_batches=num_batches) diff --git a/tests/distributed/test_training.py b/tests/distributed/test_training.py new 
file mode 100644 index 000000000..3fd19ed7c --- /dev/null +++ b/tests/distributed/test_training.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests related to training with `neuronx_distributed`.""" + +import os +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import TestCase + +from huggingface_hub import HfFolder + +from optimum.neuron.utils.cache_utils import ( + load_custom_cache_repo_name_from_hf_home, + set_custom_cache_repo_name_in_hf_home, +) +from optimum.neuron.utils.runner import ExampleRunner +from optimum.neuron.utils.testing_utils import is_trainium_test + + +_TINY_BERT_MODEL_NAME = "hf-internal-testing/tiny-random-bert" + + +@is_trainium_test +class DistributedTrainingTestCase(TestCase): + CACHE_REPO_NAME = "optimum-internal-testing/optimum-neuron-cache-for-testing" + + @classmethod + def setUpClass(cls) -> None: + orig_token = HfFolder.get_token() + orig_cache_repo = load_custom_cache_repo_name_from_hf_home() + ci_token = os.environ.get("HF_TOKEN_OPTIMUM_NEURON_CI", None) + if ci_token is not None: + HfFolder.save_token(ci_token) + set_custom_cache_repo_name_in_hf_home(cls.CACHE_REPO_NAME) + cls._token = orig_token + cls._cache_repo = orig_cache_repo + + @classmethod + def tearDownClass(cls) -> None: + if cls._token is not None: + HfFolder.save_token(cls._token) + if cls._cache_repo is not None: + set_custom_cache_repo_name_in_hf_home(cls._cache_repo) + + def test_tp_save_and_resume_from_checkpoint(self): + num_cores = 8 + precision = "bf16" + tensor_parallel_size = 2 + train_batch_size = 2 + eval_batch_size = 2 + sequence_length = 16 + max_steps = 10 + save_steps = 2 + do_eval = True + max_eval_samples = 16 + + with TemporaryDirectory() as tmpdirname: + output_dir = Path(tmpdirname) + + runner = ExampleRunner(_TINY_BERT_MODEL_NAME, "text-classification") + + first_output_dir = output_dir / "first_run" + returncode, _ = runner.run( + num_cores, + precision, + train_batch_size, + eval_batch_size=eval_batch_size, + sequence_length=sequence_length, + tensor_parallel_size=tensor_parallel_size, + max_steps=max_steps, + save_steps=save_steps, + do_eval=do_eval, + max_eval_samples=max_eval_samples, + output_dir=first_output_dir, + print_outputs=True, + ) + assert returncode == 0, "First run failed." + + # Case 1: Resuming from checkpoint by specifying a checkpoint directory. + second_output_dir = output_dir / "second_run" + returncode, _ = runner.run( + num_cores, + precision, + train_batch_size, + eval_batch_size=eval_batch_size, + sequence_length=sequence_length, + tensor_parallel_size=tensor_parallel_size, + max_steps=max_steps, + save_steps=save_steps, + do_eval=do_eval, + max_eval_samples=max_eval_samples, + output_dir=second_output_dir, + resume_from_checkpoint=first_output_dir / "checkpoint-4", + print_outputs=True, + ) + assert returncode == 0, "Second run failed." 
+ + # Case 2: Resuming from checkpoint without specifying one explicitly; in this case the last checkpoint + # found in the output directory should be used. + returncode, _ = runner.run( + num_cores, + precision, + train_batch_size, + eval_batch_size=eval_batch_size, + sequence_length=sequence_length, + tensor_parallel_size=tensor_parallel_size, + max_steps=max_steps + 10, # So that it performs more steps since we are resuming from the second run. + save_steps=save_steps, + do_eval=do_eval, + max_eval_samples=max_eval_samples, + output_dir=second_output_dir, + print_outputs=True, + ) + assert returncode == 0, "Third run failed." diff --git a/tests/test_examples.py b/tests/test_examples.py index bed75b4ec..41f0e3c65 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -378,6 +378,7 @@ def test(self): max_steps=self.MAX_STEPS, max_eval_samples=self.MAX_EVAL_SAMPLES, save_steps=self.SAVE_STEPS, + save_total_limit=1, learning_rate=self.LEARNING_RATE, tensor_parallel_size=tensor_parallel_size, disable_embedding_parallelization=disable_embedding_parallelization,