Skip to content

Commit

Permalink
Fix logging of metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun committed Oct 17, 2024
1 parent b6fb211 commit ad4e480
Showing 1 changed file with 80 additions and 2 deletions.
82 changes: 80 additions & 2 deletions optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from accelerate.state import PartialState
from accelerate.utils import AutocastKwargs, DataLoaderConfiguration, GradientAccumulationPlugin
from packaging import version
from torch import nn
from torch.utils.data import Dataset
from transformers import (
AutoModelForCausalLM,
Expand Down Expand Up @@ -1873,6 +1874,83 @@ def tokenize(element):

# class NeuronORPOTrainer(ORPOTrainer):
class NeuronORPOTrainer(_TrainerForNeuron, ORPOTrainer):
    def concatenated_forward(
        self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
    ) -> Tuple[torch.FloatTensor, ...]:
        """Run a single forward pass over the chosen and rejected halves of the batch.

        The chosen and rejected sequences are concatenated along the batch dimension
        (via ``self.concatenated_inputs``) so the policy model is called only once;
        the resulting log-probabilities and logits are then split back into their
        chosen / rejected parts using the chosen batch size.

        Args:
            model: The policy model to run the forward pass with.
            batch: Tokenized batch; must provide the keys consumed by
                ``self.concatenated_inputs`` (e.g. ``chosen_labels``,
                ``concatenated_input_ids`` after concatenation).

        Returns:
            ``(chosen_logps, rejected_logps, chosen_logits, rejected_logits,
            chosen_nll_loss)``, with ``outputs.aux_loss`` appended as a sixth
            element when ``self.aux_loss_enabled`` is set.
        """
        concatenated_batch = self.concatenated_inputs(
            batch,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
            padding_value=self.padding_value,
            device=self.accelerator.device,
        )
        # Number of chosen examples: rows [0:len_chosen] of every concatenated
        # tensor are the chosen half, rows [len_chosen:] the rejected half.
        len_chosen = batch["chosen_labels"].shape[0]

        model_kwargs = (
            {
                # Encoder-decoder models need decoder inputs derived from the labels.
                "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
            }
            if self.is_encoder_decoder
            else {}
        )

        if self.aux_loss_enabled:
            # Request router logits so MoE-style models expose their auxiliary loss.
            model_kwargs["output_router_logits"] = True

        outputs = model(
            concatenated_batch["concatenated_input_ids"],
            attention_mask=concatenated_batch["concatenated_attention_mask"],
            use_cache=False,
            **model_kwargs,
        )
        all_logits = outputs.logits

        def cross_entropy_loss(logits, labels):
            # Token-level NLL over the given (logits, labels) pair.
            if not self.is_encoder_decoder:
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            logits = logits.view(-1, logits.shape[-1])
            labels = labels.view(-1)
            # Enable model parallelism
            labels = labels.to(logits.device)
            loss = loss_fct(logits, labels)
            return loss

        if self.is_encoder_decoder:
            labels = concatenated_batch["concatenated_labels"].clone()
        else:
            # Decoder-only: the input ids are the labels, with padding positions
            # masked out so CrossEntropyLoss ignores them.
            # NOTE(review): this relies on self.label_pad_token_id matching the
            # loss's ignore_index (-100 by default) — confirm against the trainer config.
            labels = concatenated_batch["concatenated_input_ids"].clone()
            attention_mask = concatenated_batch["concatenated_attention_mask"]
            labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

        # NLL term of the ORPO loss, computed on the chosen half only.
        chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])

        # Per-sequence (length-averaged) log-probabilities for both halves.
        all_logps = self.get_batch_logps(
            all_logits,
            concatenated_batch["concatenated_labels"],
            average_log_prob=True,
            is_encoder_decoder=self.is_encoder_decoder,
            label_pad_token_id=self.label_pad_token_id,
        )

        chosen_logps = all_logps[:len_chosen]
        rejected_logps = all_logps[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        # It is important to mark the step here to materialize the graph and tensors otherwise the compiler fails in
        # `get_batch_loss_metrics` when adding `policy_rejected_logits` and `policy_chosen_logits` to the `metrics`.
        xm.mark_step()

        if self.aux_loss_enabled:
            return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

        return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)

def get_batch_loss_metrics(
self,
model,
Expand Down Expand Up @@ -1907,8 +1985,8 @@ def get_batch_loss_metrics(
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean()
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean()
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean()
# metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
# metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean()
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
Expand Down

0 comments on commit ad4e480

Please sign in to comment.