rewrite src: add VLM, Kosmos, Flamingo
i-gao committed Sep 8, 2023
1 parent 99c350f commit 2f634f0
Showing 9 changed files with 1,091 additions and 842 deletions.
1 change: 1 addition & 0 deletions open_flamingo/__init__.py
@@ -1,2 +1,3 @@
 from .src.flamingo import Flamingo
+from .src.kosmos import Kosmos
 from .src.factory import create_model_and_transforms
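Taken together with the factory changes below, the package root now exports both model families. A minimal sketch of the resulting import surface (assuming the package is installed as open_flamingo):

# The package root now exposes both model families alongside the factory.
from open_flamingo import Flamingo, Kosmos, create_model_and_transforms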
open_flamingo/src/flamingo_lm.py
@@ -1,11 +1,12 @@
 import torch.nn as nn
+import torch
 from .helpers import GatedCrossAttentionBlock
 from .utils import getattr_recursive, setattr_recursive
 
 
-class FlamingoLayer(nn.Module):
+class DecoderLayerWithCrossAttention(nn.Module):
     """
-    FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
+    DecoderLayerWithCrossAttention is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
     """
 
     def __init__(
@@ -33,17 +34,15 @@ def condition_vis_x(self, vis_x):
     def condition_media_locations(self, media_locations):
         self.media_locations = media_locations
 
-    def condition_use_cached_media(self, use_cached_media):
-        self.use_cached_media = use_cached_media
-
     def forward(
         self,
         lang_x,
         attention_mask=None,
         **decoder_layer_kwargs,
     ):
         # Cross attention
-        if self.gated_cross_attn_layer is not None:
+        contains_media = (self.media_locations == 1).any()
+        if contains_media and self.gated_cross_attn_layer is not None:
             if self.vis_x is None:
                 raise ValueError("vis_x must be conditioned before forward pass")
 
@@ -56,7 +55,6 @@ def forward(
                 lang_x,
                 self.vis_x,
                 media_locations=self.media_locations,
-                use_cached_media=self.use_cached_media,
             )
 
         # Normal decoder layer
@@ -66,7 +64,7 @@ def forward(
         return lang_x
 
 
-class FlamingoLMMixin(nn.Module):
+class CrossAttentionMixin(nn.Module):
     """
     Mixin to add cross-attention layers to a language model.
     """
@@ -80,16 +78,15 @@ def _get_decoder_layers(self):
     def _set_decoder_layers(self, value):
         setattr_recursive(self, self.decoder_layers_attr_name, value)
 
-    def init_flamingo(
+    def init_cross_attention_layers(
         self,
-        media_token_id,
         lang_hidden_size,
         vis_hidden_size,
         cross_attn_every_n_layers,
         gradient_checkpointing,
     ):
         """
-        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
+        Add gated cross attn layers to the decoder.
         """
         self.old_decoder_blocks = self._get_decoder_layers()
         self.gated_cross_attn_layers = nn.ModuleList(
@@ -102,20 +99,10 @@ def init_flamingo(
                 for layer_idx, _ in enumerate(self._get_decoder_layers())
             ]
         )
-        self.init_flamingo_layers(gradient_checkpointing)
-        self.media_token_id = media_token_id
-        self.initialized_flamingo = True
-        self._use_cached_vision_x = False
-
-    def init_flamingo_layers(self, gradient_checkpointing):
-        """
-        Re initializes the FlamingoLayers.
-        Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks
-        """
         self._set_decoder_layers(
             nn.ModuleList(
                 [
-                    FlamingoLayer(
+                    DecoderLayerWithCrossAttention(
                         gated_cross_attn_layer, decoder_layer, gradient_checkpointing
                     )
                     for gated_cross_attn_layer, decoder_layer in zip(
@@ -124,37 +111,44 @@ def init_flamingo_layers(self, gradient_checkpointing):
                 ]
             )
         )
+        self.initialized_cross_attention = True
 
-    def forward(self, input_ids, attention_mask, **kwargs):
-        """Condition the Flamingo layers on the media locations before forward()"""
-        if not self.initialized_flamingo:
-            raise ValueError(
-                "Flamingo layers are not initialized. Please call `init_flamingo` first."
-            )
+    def _condition_media_before_forward(
+        self,
+        input_ids: torch.Tensor,
+        vision_tokens: torch.Tensor = None,
+        past_media_locations: torch.Tensor = None,
+        past_vision_tokens: torch.Tensor = None,
+    ):
+        """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
+        assert (
+            self.initialized_cross_attention
+        ), "Cross attention layers have not been initialized. "
+        if past_media_locations is not None and past_vision_tokens is not None:
+            if vision_tokens is not None:
+                updated_vision_tokens = torch.cat(
+                    [
+                        past_vision_tokens,
+                        vision_tokens,
+                    ],
+                    dim=1,
+                )
+            else:
+                updated_vision_tokens = past_vision_tokens
+            updated_media_locations = torch.cat(
+                [
+                    past_media_locations,
+                    input_ids == self.media_token_id,
+                ],
+                dim=1,
+            )
-
-        media_locations = input_ids == self.media_token_id
-
-        # if there are media already cached and we're generating and there are no media tokens in the input,
-        # we'll assume that ALL input tokens should attend to the last previous media that is cached.
-        # this is especially important for HF generate() compatibility, since generate() calls forward()
-        # repeatedly one token at a time (with no media tokens).
-        # without this check, the model would not attend to any images when generating (after the first token)
-        use_cached_media_locations = (
-            self._use_cached_vision_x
-            and self.is_conditioned()
-            and not media_locations.any()
-        )
+        else:
+            updated_vision_tokens = vision_tokens
+            updated_media_locations = input_ids == self.media_token_id
 
         for layer in self._get_decoder_layers():
-            if not use_cached_media_locations:
-                layer.condition_media_locations(media_locations)
-            layer.condition_use_cached_media(use_cached_media_locations)
-
-        # package arguments for the other parent's forward. since we don't know the order of the arguments,
-        # make them all kwargs
-        kwargs["input_ids"] = input_ids
-        kwargs["attention_mask"] = attention_mask
-        return super().forward(**kwargs)  # Call the other parent's forward method
+            layer.condition_vis_x(updated_vision_tokens)
+            layer.condition_media_locations(updated_media_locations)
 
     def is_conditioned(self) -> bool:
         """Check whether all decoder layers are already conditioned."""
@@ -164,4 +158,3 @@ def clear_conditioned_layers(self):
         for layer in self._get_decoder_layers():
             layer.condition_vis_x(None)
             layer.condition_media_locations(None)
-            layer.condition_use_cached_media(None)
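The rewrite drops the use_cached_media flag: instead of telling layers to fall back to cached media during generation, _condition_media_before_forward concatenates the cached media locations and vision tokens with the current step's, so every GatedCrossAttentionBlock always sees the full history. A toy sketch of that bookkeeping (the media token id and shapes are made up for illustration):

import torch

MEDIA_TOKEN_ID = 32000  # hypothetical id for the <image> token

# prefill: the prompt contains one media token
input_ids = torch.tensor([[101, 32000, 7, 9]])
media_locations = input_ids == MEDIA_TOKEN_ID  # tensor([[False, True, False, False]])

# decode step: generate() calls forward() with a single new token and no
# media; concatenating along the sequence dim (dim=1) keeps the earlier
# media position visible to every cross-attention layer
next_ids = torch.tensor([[42]])
updated_media_locations = torch.cat(
    [media_locations, next_ids == MEDIA_TOKEN_ID], dim=1
)
print(updated_media_locations)  # tensor([[False,  True, False, False, False]])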
135 changes: 60 additions & 75 deletions open_flamingo/src/factory.py
@@ -4,23 +4,21 @@
 import open_clip
 
 from .flamingo import Flamingo
-from .flamingo_lm import FlamingoLMMixin
+from .kosmos import Kosmos
 from .utils import extend_instance
-from .mllm import MLLM
 
 
 def create_model_and_transforms(
+    model_family: str,
     clip_vision_encoder_path: str,
     clip_vision_encoder_pretrained: str,
-    lang_encoder_path: str,
+    lang_model_path: str,
     tokenizer_path: str,
-    cross_attn_every_n_layers: int = 1,
     use_local_files: bool = False,
     decoder_layers_attr_name: str = None,
-    freeze_lm_embeddings: bool = False,
-    model_family: str = "flamingo",
-    freeze_backbone_mllm: bool = False,
-    **flamingo_kwargs,
+    cache_dir: Optional[str] = None,
+    gradient_checkpointing: bool = False,
+    **model_kwargs,
 ):
     """
     Initialize a Flamingo model from a pretrained vision encoder and language encoder.
@@ -29,50 +27,50 @@ def create_model_and_transforms(
     Args:
         clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
         clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
-        lang_encoder_path (str): path to pretrained language encoder
+        lang_model_path (str): path to pretrained language encoder
         tokenizer_path (str): path to pretrained tokenizer
-        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
         use_local_files (bool, optional): whether to use local files. Defaults to False.
         decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
-        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
+        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
+        gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
     Returns:
         Flamingo: Flamingo model from pretrained vision and language encoders
         Image processor: Pipeline to preprocess input images
         Tokenizer: A tokenizer for the language model
     """
 
+    assert model_family in ("flamingo", "kosmos")
+
     # load vision encoder
     vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
         clip_vision_encoder_path,
         pretrained=clip_vision_encoder_pretrained,
+        cache_dir=cache_dir,
     )
     # set the vision encoder to output the visual features
     vision_encoder.visual.output_tokens = True
+    vision_encoder = vision_encoder.visual
 
+    # load tokenizer and ensure there is a pad token
     text_tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
+        cache_dir=cache_dir,
    )
-    # add Flamingo special tokens to the tokenizer
-    text_tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
-    )
     if text_tokenizer.pad_token is None:
         # Issue: GPT models don't have a pad token, which we use to
         # modify labels for the loss.
         text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
 
-    lang_encoder = AutoModelForCausalLM.from_pretrained(
-        lang_encoder_path,
+    # load language model
+    lang_model = AutoModelForCausalLM.from_pretrained(
+        lang_model_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
+        cache_dir=cache_dir,
     )
 
-    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
-    if "mpt-1b-redpajama-200b" in lang_encoder_path:
+    ## hacks for MPT-1B, which doesn't have a get_input_embeddings method
+    if "mpt-1b-redpajama-200b" in lang_model_path:
 
         class EmbeddingFnMixin:
             def get_input_embeddings(self):
@@ -81,69 +79,56 @@ def get_input_embeddings(self):
             def set_input_embeddings(self, new_embeddings):
                 self.transformer.wte = new_embeddings
 
-    extend_instance(lang_encoder, EmbeddingFnMixin)
+    extend_instance(lang_model, EmbeddingFnMixin)
 
+    # init the model
     if model_family == "flamingo":
-        # convert LM to FlamingoLM
-        extend_instance(lang_encoder, FlamingoLMMixin)
-
-        if decoder_layers_attr_name is None:
-            decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
-        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
-        lang_encoder.resize_token_embeddings(
-            len(text_tokenizer)
-        )
+        if decoder_layers_attr_name is None:
+            decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)
 
         model = Flamingo(
-            vision_encoder,
-            lang_encoder,
-            text_tokenizer.encode("<|endofchunk|>")[-1],
-            text_tokenizer.encode("<image>")[-1],
-            vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-                "width"
-            ],
-            cross_attn_every_n_layers=cross_attn_every_n_layers,
-            **flamingo_kwargs,
-        )
-
-        # Freeze all parameters
-        model.requires_grad_(False)
-        assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
-
-        # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
-        model.perceiver.requires_grad_(True)
-        model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
-
-        # TODO: FIX this. Currently we are just training all embeddings unless freeze_lm_embeddings is on in which case we only train <image> and <eoc> embeddings
-        model.lang_encoder.get_input_embeddings().requires_grad_(True)
-        model.lang_encoder.get_output_embeddings().requires_grad_(True)
-
-        print(
-            f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+            vision_encoder=vision_encoder,
+            lang_model=lang_model,
+            vis_feature_dim=open_clip.get_model_config(clip_vision_encoder_path)[
+                "vision_cfg"
+            ]["width"],
+            tokenizer_vocab_size=len(text_tokenizer),
+            gradient_checkpointing=gradient_checkpointing,
+            decoder_layers_attr_name=decoder_layers_attr_name,
+            pad_token=text_tokenizer.pad_token,
+            **model_kwargs,
         )

-    elif model_family == "mllm":
-        lang_encoder.resize_token_embeddings(len(text_tokenizer))
-
-        model = MLLM(
-            vision_model=vision_encoder,
-            language_model=lang_encoder,
-            padding_token_id=text_tokenizer.pad_token_id,
-            media_token_id=text_tokenizer.encode("<image>")[-1],
-            vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-                "width"
-            ],
+    elif model_family == "kosmos":
+        model = Kosmos(
+            vision_encoder=vision_encoder,
+            lang_model=lang_model,
+            vis_feature_dim=open_clip.get_model_config(clip_vision_encoder_path)[
+                "vision_cfg"
+            ]["width"],
+            tokenizer_vocab_size=len(text_tokenizer),
+            gradient_checkpointing=gradient_checkpointing,
+            pad_token=text_tokenizer.pad_token,
+            **model_kwargs,
         )
 
-    # Freeze vision encoder
-    model.vision_model.requires_grad_(False)
-    if freeze_backbone_mllm:
-        model.language_model.requires_grad_(False)
-
-    print(
-        f"MLLM model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
-    )
+    # add special tokens to the tokenizer and language models
+    text_tokenizer.add_special_tokens(
+        {"additional_special_tokens": list(model.special_tokens.values())}
+    )
+    model.lang_model.config.vocab_size = len(text_tokenizer)
+    model.set_special_token_ids(
+        {
+            v: text_tokenizer.convert_tokens_to_ids(v)
+            for v in model.special_tokens.values()
+        }
+    )
 
+    # freeze appropriate parameters
+    model.set_trainable()
+    print(
+        f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
+    )
     return model, image_processor, text_tokenizer
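For orientation, here is a hypothetical call under the new signature; the checkpoint names are placeholders, and Flamingo-only options such as cross_attn_every_n_layers are assumed to travel through **model_kwargs rather than being named parameters:

model, image_processor, tokenizer = create_model_and_transforms(
    model_family="flamingo",                 # or "kosmos"
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_model_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,             # forwarded via **model_kwargs
    gradient_checkpointing=False,
)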

