Commit 0b1c926

remove z-loss mess
anas-awadalla committed Feb 24, 2024
1 parent feba465 commit 0b1c926
Showing 1 changed file with 6 additions and 77 deletions.
83 changes: 6 additions & 77 deletions open_flamingo/train/losses.py
@@ -1,16 +1,12 @@
 from open_flamingo.src.vlm import VLM
 import torch
-from torch import Tensor
-from torch.nn import CrossEntropyLoss
 
-SUPPORTED_LOSSES = ["next_token_prediction", "next_token_prediction_with_z_loss"]
+SUPPORTED_LOSSES = ["next_token_prediction"]
 
 
 def get_loss_fn(loss_name):
     if loss_name == "next_token_prediction":
         return NextTokenPrediction()
-    elif loss_name == "next_token_prediction_with_z_loss":
-        return NextTokenPredictionWithZLoss()
     else:
         raise ValueError(
             f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}"
@@ -47,10 +43,10 @@ def __call__(
         raise NotImplementedError
 
 
-class NextTokenPredictionWithZLoss(Loss):
+class NextTokenPrediction(Loss):
     @property
     def name(self):
-        return "next_token_prediction_with_z_loss"
+        return "next_token_prediction"
 
     def __call__(
         self,
@@ -60,7 +56,6 @@ def __call__(
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         autocast: callable,
-        z_loss_eps: float = 1e-4,
     ):
         # set up labels; language model is expected to handle shifting
         labels = input_ids.clone()
@@ -74,55 +69,15 @@ def __call__(
 
         # call forward
         with autocast():
-            logits = model(
+            loss = model(
                 vision_x=images,
                 lang_x=input_ids,
                 attention_mask=attention_mask,
                 labels=labels,
-            )[1]
-
-        logits = logits.float()
-
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLossWithZLoss(eps=z_loss_eps)
-        shift_logits = shift_logits.view(-1, unwrap_model(model).lang_model.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
+            )[0]
         return loss
-
-
-class NextTokenPrediction(NextTokenPredictionWithZLoss):
-    # same as NextTokenPredictionWithZLoss, but with z_loss_eps = 0
-    @property
-    def name(self):
-        return "next_token_prediction"
-
-    def __call__(
-        self,
-        model: VLM,
-        tokenizer,
-        images: torch.Tensor,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        autocast: callable,
-    ):
-        return super().__call__(
-            model=model,
-            tokenizer=tokenizer,
-            images=images,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            autocast=autocast,
-            z_loss_eps=0,
-        )
 
 
 def unwrap_model(model):
     """
     Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
@@ -132,30 +87,4 @@ def unwrap_model(model):
     ):
         return model.module
     else:
-        return model
-
-
-# From OpenLM (https://github.com/mlfoundations/open_lm/blob/main/open_lm/losses.py)
-class CrossEntropyLossWithZLoss(CrossEntropyLoss):
-    def __init__(
-        self,
-        eps: float = 1e-4,
-        weight: Tensor = None,
-        size_average=None,
-        ignore_index: int = -100,
-        reduce=None,
-        reduction: str = "mean",
-        label_smoothing: float = 0,
-    ) -> None:
-        super().__init__(
-            weight, size_average, ignore_index, reduce, reduction, label_smoothing
-        )
-        self.eps = eps
-
-    def forward(self, input: Tensor, target: Tensor) -> Tensor:
-        if self.eps == 0:
-            return super().forward(input, target)
-
-        return super().forward(input, target) + self.eps * torch.square(
-            torch.logsumexp(input, dim=-1).mean()
-        )
+        return model
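
For reference, the deleted CrossEntropyLossWithZLoss (adapted from OpenLM) implemented z-loss regularization: standard cross-entropy plus an eps-weighted penalty on the squared mean of the log-partition term logsumexp(logits). A minimal standalone sketch of that computation, assuming flattened (num_tokens, vocab_size) logits; the function name here is hypothetical and not part of the repository:

import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(
    logits: torch.Tensor,    # (num_tokens, vocab_size)
    labels: torch.Tensor,    # (num_tokens,)
    eps: float = 1e-4,
    ignore_index: int = -100,
) -> torch.Tensor:
    # Standard next-token cross-entropy over the flattened tokens.
    ce = F.cross_entropy(logits, labels, ignore_index=ignore_index)
    # z-loss penalty, mirroring the removed class:
    # eps * square(mean over tokens of logsumexp over the vocabulary).
    return ce + eps * torch.square(torch.logsumexp(logits, dim=-1).mean())

After this change, NextTokenPrediction no longer shifts or flattens logits itself: it passes labels into the model's forward call and takes the returned loss at index 0, delegating cross-entropy computation to the underlying language model.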
