From 7ebf9750566b062a89bf6a048d414011d1f371b4 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 16 Apr 2024 16:15:27 +0200
Subject: [PATCH] [WIP] Update trainer.py

---
 optimum/neuron/trainers.py | 184 ++++++++++++++++++++++++++-----------
 1 file changed, 132 insertions(+), 52 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index 346982491..774085033 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -17,7 +17,6 @@
 import copy
 import math
 import os
-import random
 import shutil
 import sys
 import time
@@ -27,7 +26,7 @@
 import numpy as np
 import torch
 from accelerate import __version__ as accelerate_version
-from accelerate.utils import AutocastKwargs
+from accelerate.utils import AutocastKwargs, DataLoaderConfiguration, GradientAccumulationPlugin
 from packaging import version
 from torch.utils.data import Dataset
 from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
@@ -61,7 +60,7 @@
     speed_metrics,
 )
 from transformers.training_args import ParallelMode
-from transformers.utils import WEIGHTS_NAME, is_apex_available, is_sagemaker_mp_enabled
+from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled
 
 from ..utils import logging
 from .accelerate import NeuronAccelerator, NeuronDistributedType
@@ -224,16 +223,58 @@ def prepare_args_for_precompilation(self, args: "TrainingArguments"):
         args.do_predict = False
 
     def create_accelerator_and_postprocess(self):
+        grad_acc_kwargs = {}
+        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
+
+        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+        if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
+            # raise because we do not know which setting is intended.
+            raise ValueError(
+                "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
+                "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+            )
+        elif "num_steps" not in grad_acc_kwargs:
+            # take the gradient_accumulation_steps setting from TrainingArguments.
+            grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps
+
+        grad_acc_kwargs["sync_with_dataloader"] = False
+
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+        accelerator_config = self.args.accelerator_config.to_dict()
+
+        if is_accelerate_available("0.28.0"):
+            dataloader_config = DataLoaderConfiguration(
+                split_batches=accelerator_config.pop("split_batches"),
+                dispatch_batches=accelerator_config.pop("dispatch_batches"),
+                even_batches=accelerator_config.pop("even_batches"),
+                use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
+            )
+            # this would have been updated above, no need for it anymore
+            accelerator_config.pop("gradient_accumulation_kwargs")
+
+        args = {
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+            "gradient_accumulation_plugin": gradient_accumulation_plugin,
+        }
+        if is_accelerate_available("0.28.0"):
+            args["dataloader_config"] = dataloader_config
+        else:
+            args.update(accelerator_config)
+
         # create accelerator object
         self.accelerator = NeuronAccelerator(
-            deepspeed_plugin=self.args.deepspeed_plugin,
-            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+            **args,
             mp_plugin=self.args.mp_plugin,
             zero_1=self.args.zero_1,
             mixed_precision="bf16" if self.args.bf16 else "no",
             autocast_backend=self.args.half_precision_backend,
         )
 
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
+        self.gather_function = self.accelerator.gather_for_metrics
+
         # deepspeed and accelerate flags covering both trainer args and accelerate launcher
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
@@ -241,18 +282,36 @@ def create_accelerator_and_postprocess(self):
         # post accelerator creation setup
         if self.is_fsdp_enabled:
             fsdp_plugin = self.accelerator.state.fsdp_plugin
-            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
-            fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+            if is_accelerate_available("0.23.0"):
+                fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
+                    "activation_checkpointing", fsdp_plugin.activation_checkpointing
+                )
+                if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+                    raise ValueError(
+                        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+                        "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+                        "when using FSDP."
+                    )
 
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+            self.propagate_args_to_deepspeed()
 
-                ds_plugin = self.accelerator.state.deepspeed_plugin
+        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+        if (
+            self.args.save_only_model
+            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+            and self.args.load_best_model_at_end
+        ):
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
 
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
+        if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")
 
     @requires_torch_neuronx
     def synchronize_hub_cache(self):
@@ -283,6 +342,9 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
     def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]:
         return torch.utils.data.SequentialSampler(eval_dataset)
 
+    def get_num_trainable_parameters(self):
+        return get_model_param_count(self.model, trainable_only=True)
+
     @staticmethod
     def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
         optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
@@ -382,7 +444,7 @@ def prediction_step(
             return (loss, None, None)
         return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
 
-    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             logs: Dict[str, float] = {}
 
@@ -410,6 +472,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             logs["learning_rate"] = self._get_learning_rate()
 
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+
             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
             self.store_flos()
@@ -419,17 +484,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
 
         metrics = None
         if self.control.should_evaluate:
-            if isinstance(self.eval_dataset, dict):
-                metrics = {}
-                for eval_dataset_name, eval_dataset in self.eval_dataset.items():
-                    dataset_metrics = self.evaluate(
-                        eval_dataset=eval_dataset,
-                        ignore_keys=ignore_keys_for_eval,
-                        metric_key_prefix=f"eval_{eval_dataset_name}",
-                    )
-                    metrics.update(dataset_metrics)
-            else:
-                metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+            metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
             self._report_to_hp_search(trial, self.state.global_step, metrics)
 
             # Run delayed LR scheduler now that metrics are populated
@@ -471,7 +526,7 @@ def _save_xla(self, output_dir: Optional[str] = None):
             Parallelizer.save_model_sharded_checkpoint(
                 self.model,
                 output_dir,
-                optimizer=self.optimizer,
+                optimizer=self.optimizer if not self.args.save_only_model else None,
                 use_xser=self.accelerator.state.mp_plugin.use_xser,
                 async_save=self.accelerator.state.mp_plugin.async_save,
                 num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
@@ -546,10 +601,18 @@ def _save_checkpoint(self, model, trial, metrics=None):
             xm.rendezvous("saving_optimizer_states")
             xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
 
+        if not self.args.save_only_model:
+            # Save optimizer and scheduler
+            self._save_optimizer_and_scheduler(output_dir)
+
         with warnings.catch_warnings(record=True) as caught_warnings:
             xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
         reissue_pt_warnings(caught_warnings)
 
+        if not self.args.save_only_model:
+            # Save RNG state
+            self._save_rng_state(output_dir)
+
         # Determine the new best metric / best model checkpoint
         if metrics is not None and self.args.metric_for_best_model is not None:
             metric_to_check = self.args.metric_for_best_model
@@ -570,24 +633,10 @@ def _save_checkpoint(self, model, trial, metrics=None):
         if self.args.should_save:
             self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
 
-        # Save RNG state in non-distributed training
-        rng_states = {
-            "python": random.getstate(),
-            "numpy": np.random.get_state(),
-            "cpu": torch.random.get_rng_state(),
-        }
-
-        rng_states["xla"] = xm.get_rng_state()
-
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
 
-        if self.args.world_size <= 1:
-            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
-        else:
-            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
-
         if self.args.push_to_hub:
             self._push_from_checkpoint(output_dir)
 
@@ -624,6 +673,7 @@ def _inner_training_loop(
         self._train_batch_size = batch_size
         if is_main_worker():
             logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+
         # Data loader and number of training steps
         train_dataloader = self.get_train_dataloader()
 
@@ -797,6 +847,7 @@
             os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
         ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            self.compare_trainer_and_checkpoint_args(self.args, self.state)
             epochs_trained = self.state.global_step // num_update_steps_per_epoch
             if not args.ignore_data_skip:
                 steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -843,13 +894,8 @@
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
 
-        # It should be equivalent but prefer to use the `zero_grad` method from the optimizer when doing pipeline
-        # parallelism.
-        if isinstance(model, NxDPPModel):
-            self.optimizer.zero_grad()
-        else:
-            model.zero_grad()
-
+        self.optimizer.zero_grad()
+        grad_norm: Optional[float] = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
@@ -903,6 +949,23 @@ def _inner_training_loop(
             step = -1
             for step, inputs in enumerate(epoch_iterator):
                 total_batched_samples += 1
+
+                if self.args.include_num_input_tokens_seen:
+                    main_input_name = getattr(self.model, "main_input_name", "input_ids")
+                    if main_input_name not in inputs:
+                        logger.warning(
+                            "Tried to track the number of tokens seen, however the current model is "
+                            "not configured properly to know what item is the input. To fix this, add "
+                            "a `main_input_name` attribute to the model class you are using."
+                        )
+                    else:
+                        input_device = inputs[main_input_name].device
+                        self.state.num_input_tokens_seen += torch.sum(
+                            self.accelerator.gather(
+                                torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
+                            )
+                        ).item()
+
                 if rng_to_sync:
                     self._load_rng_state(resume_from_checkpoint)
                     rng_to_sync = False
@@ -933,6 +996,11 @@
                     # if loss is nan or inf simply add the average of previous logged losses
                     tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                 else:
+                    if tr_loss.device != tr_loss_step.device:
+                        raise ValueError(
+                            f"Calculated loss must be on the original device: {tr_loss.device} but device in use is "
+                            f"{tr_loss_step.device}"
+                        )
                     tr_loss += tr_loss_step
 
                 self.current_flos += float(self.floating_point_ops(inputs))
@@ -960,17 +1028,23 @@
 
                        if is_sagemaker_mp_enabled() and args.fp16:
                             self.optimizer.clip_master_grads(args.max_grad_norm)
+                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                         elif self.use_apex:
                             # Revert to normal clipping otherwise, handling Apex or full precision
                             torch.nn.utils.clip_grad_norm_(
                                 amp.master_params(self.optimizer),
                                 args.max_grad_norm,
                             )
+                            _grad_norm = torch.nn.utils.clip_grad_norm_(
+                                amp.master_params(self.optimizer),
+                                args.max_grad_norm,
+                            )
                         else:
-                            self.accelerator.clip_grad_norm_(
+                            _grad_norm = self.accelerator.clip_grad_norm_(
                                 model.parameters(),
                                 args.max_grad_norm,
                             )
+                        grad_norm = _grad_norm
 
                     # Optimizer step
                     self.optimizer.step()
@@ -987,11 +1061,16 @@
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
 
-                    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
+                    # PyTorch/XLA relies on the data loader to insert the mark_step for
+                    # each step. Since we are breaking the loop early, we need to manually
+                    # insert the mark_step here.
+                    if is_torch_xla_available():
+                        xm.mark_step()
                     break
             if step < 0:
                 if is_main_worker():
@@ -1003,7 +1082,7 @@
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_xla_available():
@@ -1036,7 +1115,8 @@
         # add remaining tr_loss
         self._total_loss_scalar += tr_loss.item()
-        train_loss = self._total_loss_scalar / self.state.global_step
+        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
+        train_loss = self._total_loss_scalar / effective_global_step
 
         metrics = speed_metrics(
             "train",