[WIP] Update trainer.py
michaelbenayoun committed Apr 16, 2024
1 parent 17cf147 commit 7ebf975
Showing 1 changed file with 132 additions and 52 deletions.
184 changes: 132 additions & 52 deletions optimum/neuron/trainers.py
@@ -17,7 +17,6 @@
import copy
import math
import os
import random
import shutil
import sys
import time
@@ -27,7 +26,7 @@
import numpy as np
import torch
from accelerate import __version__ as accelerate_version
from accelerate.utils import AutocastKwargs
from accelerate.utils import AutocastKwargs, DataLoaderConfiguration, GradientAccumulationPlugin
from packaging import version
from torch.utils.data import Dataset
from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
@@ -61,7 +60,7 @@
speed_metrics,
)
from transformers.training_args import ParallelMode
from transformers.utils import WEIGHTS_NAME, is_apex_available, is_sagemaker_mp_enabled
from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled

from ..utils import logging
from .accelerate import NeuronAccelerator, NeuronDistributedType
@@ -224,35 +223,95 @@ def prepare_args_for_precompilation(self, args: "TrainingArguments"):
args.do_predict = False

def create_accelerator_and_postprocess(self):
grad_acc_kwargs = {}
if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs

# check if num_steps is attempted to be passed in gradient_accumulation_kwargs
if "num_steps" in grad_acc_kwargs and self.args.gradient_accumulation_steps > 1:
# raise because we do not know which setting is intended.
raise ValueError(
"The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`"
"If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
)
elif "num_steps" not in grad_acc_kwargs:
# take the gradient_accumulation_steps setting from TrainingArguments.
grad_acc_kwargs["num_steps"] = self.args.gradient_accumulation_steps

grad_acc_kwargs["sync_with_dataloader"] = False

gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

accelerator_config = self.args.accelerator_config.to_dict()

if is_accelerate_available("0.28.0"):
dataloader_config = DataLoaderConfiguration(
split_batches=accelerator_config.pop("split_batches"),
dispatch_batches=accelerator_config.pop("dispatch_batches"),
even_batches=accelerator_config.pop("even_batches"),
use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
)
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")

args = {
"deepspeed_plugin": self.args.deepspeed_plugin,
"gradient_accumulation_plugin": gradient_accumulation_plugin,
}
if is_accelerate_available("0.28.0"):
args["dataloader_config"] = dataloader_config
else:
args.update(accelerator_config)

# create accelerator object
self.accelerator = NeuronAccelerator(
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
*args,
mp_plugin=self.args.mp_plugin,
zero_1=self.args.zero_1,
mixed_precision="bf16" if self.args.bf16 else "no",
autocast_backend=self.args.half_precision_backend,
)

# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics

# deepspeed and accelerate flags covering both trainer args and accelerate launcher
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None

# post accelerator creation setup
if self.is_fsdp_enabled:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
if is_accelerate_available("0.23.0"):
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
"activation_checkpointing", fsdp_plugin.activation_checkpointing
)
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
raise ValueError(
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
"when using FSDP."
)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
from transformers.deepspeed import HfTrainerDeepSpeedConfig
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
self.propagate_args_to_deepspeed()

ds_plugin = self.accelerator.state.deepspeed_plugin
# `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
if (
self.args.save_only_model
and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
and self.args.load_best_model_at_end
):
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")

ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
ds_plugin.hf_ds_config.trainer_config_process(self.args)
# `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")

@requires_torch_neuronx
def synchronize_hub_cache(self):
@@ -283,6 +342,9 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
def _get_eval_sampler(self, eval_dataset: torch.utils.data.Dataset) -> Optional[torch.utils.data.Sampler]:
return torch.utils.data.SequentialSampler(eval_dataset)

def get_num_trainable_parameters(self):
return get_model_param_count(self.model, trainable_only=True)

@staticmethod
def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
optimizer_cls, optimizer_kwargs = transformers_get_optimizer_cls_and_kwargs(args)
@@ -382,7 +444,7 @@ def prediction_step(
return (loss, None, None)
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
logs: Dict[str, float] = {}

@@ -410,6 +472,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()

if grad_norm is not None:
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
@@ -419,17 +484,7 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for

metrics = None
if self.control.should_evaluate:
if isinstance(self.eval_dataset, dict):
metrics = {}
for eval_dataset_name, eval_dataset in self.eval_dataset.items():
dataset_metrics = self.evaluate(
eval_dataset=eval_dataset,
ignore_keys=ignore_keys_for_eval,
metric_key_prefix=f"eval_{eval_dataset_name}",
)
metrics.update(dataset_metrics)
else:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, self.state.global_step, metrics)

# Run delayed LR scheduler now that metrics are populated
@@ -471,7 +526,7 @@ def _save_xla(self, output_dir: Optional[str] = None):
Parallelizer.save_model_sharded_checkpoint(
self.model,
output_dir,
optimizer=self.optimizer,
optimizer=self.optimizer if not self.args.save_only_model else None,
use_xser=self.accelerator.state.mp_plugin.use_xser,
async_save=self.accelerator.state.mp_plugin.async_save,
num_local_ranks_per_step=self.accelerator.state.mp_plugin.num_local_ranks_per_step,
@@ -546,10 +601,18 @@ def _save_checkpoint(self, model, trial, metrics=None):
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)

with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
reissue_pt_warnings(caught_warnings)

if not self.args.save_only_model:
# Save RNG state
self._save_rng_state(output_dir)

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
@@ -570,24 +633,10 @@
if self.args.should_save:
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

# Save RNG state in non-distributed training
rng_states = {
"python": random.getstate(),
"numpy": np.random.get_state(),
"cpu": torch.random.get_rng_state(),
}

rng_states["xla"] = xm.get_rng_state()

# A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
# not yet exist.
os.makedirs(output_dir, exist_ok=True)

if self.args.world_size <= 1:
torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
else:
torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))

if self.args.push_to_hub:
self._push_from_checkpoint(output_dir)

@@ -624,6 +673,7 @@ def _inner_training_loop(
self._train_batch_size = batch_size
if is_main_worker():
logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")

# Data loader and number of training steps
train_dataloader = self.get_train_dataloader()

@@ -797,6 +847,7 @@ def _inner_training_loop(
os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
):
self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
self.compare_trainer_and_checkpoint_args(self.args, self.state)
epochs_trained = self.state.global_step // num_update_steps_per_epoch
if not args.ignore_data_skip:
steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
@@ -843,13 +894,8 @@ def _inner_training_loop(
self._total_loss_scalar = 0.0
self._globalstep_last_logged = self.state.global_step

# It should be equivalent but prefer to use the `zero_grad` method from the optimizer when doing pipeline
# parallelism.
if isinstance(model, NxDPPModel):
self.optimizer.zero_grad()
else:
model.zero_grad()

self.optimizer.zero_grad()
grad_norm: Optional[float] = None
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
@@ -903,6 +949,23 @@ def _inner_training_loop(
step = -1
for step, inputs in enumerate(epoch_iterator):
total_batched_samples += 1

if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if main_input_name not in inputs:
logger.warning(
"Tried to track the number of tokens seen, however the current model is "
"not configured properly to know what item is the input. To fix this, add "
"a `main_input_name` attribute to the model class you are using."
)
else:
input_device = inputs[main_input_name].device
self.state.num_input_tokens_seen += torch.sum(
self.accelerator.gather(
torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
)
).item()

if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False
@@ -933,6 +996,11 @@ def _inner_training_loop(
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is "
f"{tr_loss_step.device}"
)
tr_loss += tr_loss_step

self.current_flos += float(self.floating_point_ops(inputs))
@@ -960,17 +1028,23 @@ def _inner_training_loop(

if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
elif self.use_apex:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
_grad_norm = torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer),
args.max_grad_norm,
)
else:
self.accelerator.clip_grad_norm_(
_grad_norm = self.accelerator.clip_grad_norm_(
model.parameters(),
args.max_grad_norm,
)
grad_norm = _grad_norm

# Optimizer step
self.optimizer.step()
@@ -987,11 +1061,16 @@ def _inner_training_loop(
self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)

self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
else:
self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

if self.control.should_epoch_stop or self.control.should_training_stop:
# PyTorch/XLA relies on the data loader to insert the mark_step for
# each step. Since we are breaking the loop early, we need to manually
# insert the mark_step here.
if is_torch_xla_available():
xm.mark_step()
break
if step < 0:
if is_main_worker():
@@ -1003,7 +1082,7 @@ def _inner_training_loop(
self.control.should_training_stop = True

self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
if is_torch_xla_available():
@@ -1036,7 +1115,8 @@ def _inner_training_loop(

# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step
effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError
train_loss = self._total_loss_scalar / effective_global_step

metrics = speed_metrics(
"train",
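The main functional change above is in create_accelerator_and_postprocess: instead of passing gradient_accumulation_steps directly to NeuronAccelerator, the trainer now builds a GradientAccumulationPlugin and, on accelerate >= 0.28.0, a DataLoaderConfiguration. The snippet below is a minimal standalone sketch of that version-gated kwargs assembly, not the trainer's actual code path; the accelerator_config dict and gradient_accumulation_steps value are hypothetical stand-ins for the real TrainingArguments / AcceleratorConfig objects.

# Minimal sketch of the version-gated accelerator kwargs assembly shown in the diff above.
# The inputs below are hypothetical stand-ins, not the trainer's real objects.
from accelerate.utils import GradientAccumulationPlugin
from transformers.utils import is_accelerate_available

gradient_accumulation_steps = 4  # stands in for TrainingArguments.gradient_accumulation_steps
accelerator_config = {           # stands in for AcceleratorConfig.to_dict() (minus grad-acc kwargs)
    "split_batches": False,
    "dispatch_batches": None,
    "even_batches": True,
    "use_seedable_sampler": True,
}

# Gradient accumulation is always expressed through a plugin; the dataloader must not
# drive the sync, hence sync_with_dataloader=False (matching the diff above).
gradient_accumulation_plugin = GradientAccumulationPlugin(
    num_steps=gradient_accumulation_steps,
    sync_with_dataloader=False,
)

accelerator_kwargs = {"gradient_accumulation_plugin": gradient_accumulation_plugin}

if is_accelerate_available("0.28.0"):
    # accelerate >= 0.28.0 groups dataloader behavior into a dedicated dataclass.
    from accelerate.utils import DataLoaderConfiguration

    accelerator_kwargs["dataloader_config"] = DataLoaderConfiguration(
        split_batches=accelerator_config.pop("split_batches"),
        dispatch_batches=accelerator_config.pop("dispatch_batches"),
        even_batches=accelerator_config.pop("even_batches"),
        use_seedable_sampler=accelerator_config.pop("use_seedable_sampler"),
    )
else:
    # Older accelerate versions still accept these options as flat keyword arguments.
    accelerator_kwargs.update(accelerator_config)

# The trainer then forwards these kwargs (plus the Neuron-specific arguments such as
# mp_plugin, zero_1, mixed_precision and autocast_backend) to NeuronAccelerator.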
