From 2f634f0740b83fb7cccabfe97891f064858deb91 Mon Sep 17 00:00:00 2001 From: i-gao Date: Thu, 7 Sep 2023 14:51:31 -0700 Subject: [PATCH] rewrite src: add VLM, Kosmos, Flamingo --- open_flamingo/__init__.py | 1 + .../src/{flamingo_lm.py => cross_attn_lm.py} | 93 ++- open_flamingo/src/factory.py | 135 ++-- open_flamingo/src/flamingo.py | 352 ++--------- open_flamingo/src/helpers.py | 299 ++++++++- open_flamingo/src/kosmos.py | 46 ++ open_flamingo/src/mllm.py | 372 ----------- open_flamingo/src/utils.py | 40 ++ open_flamingo/src/vlm.py | 595 ++++++++++++++++++ 9 files changed, 1091 insertions(+), 842 deletions(-) rename open_flamingo/src/{flamingo_lm.py => cross_attn_lm.py} (59%) create mode 100644 open_flamingo/src/kosmos.py delete mode 100644 open_flamingo/src/mllm.py create mode 100644 open_flamingo/src/vlm.py diff --git a/open_flamingo/__init__.py b/open_flamingo/__init__.py index ab67750b..3455ccbf 100644 --- a/open_flamingo/__init__.py +++ b/open_flamingo/__init__.py @@ -1,2 +1,3 @@ from .src.flamingo import Flamingo +from .src.kosmos import Kosmos from .src.factory import create_model_and_transforms diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/cross_attn_lm.py similarity index 59% rename from open_flamingo/src/flamingo_lm.py rename to open_flamingo/src/cross_attn_lm.py index c4933e9d..54b15e5a 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/cross_attn_lm.py @@ -1,11 +1,12 @@ import torch.nn as nn +import torch from .helpers import GatedCrossAttentionBlock from .utils import getattr_recursive, setattr_recursive -class FlamingoLayer(nn.Module): +class DecoderLayerWithCrossAttention(nn.Module): """ - FlamingoLayer is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. + DecoderLayerWithCrossAttention is a wrapper around the GatedCrossAttentionBlock and DecoderLayer. """ def __init__( @@ -33,9 +34,6 @@ def condition_vis_x(self, vis_x): def condition_media_locations(self, media_locations): self.media_locations = media_locations - def condition_use_cached_media(self, use_cached_media): - self.use_cached_media = use_cached_media - def forward( self, lang_x, @@ -43,7 +41,8 @@ def forward( **decoder_layer_kwargs, ): # Cross attention - if self.gated_cross_attn_layer is not None: + contains_media = (self.media_locations == 1).any() + if contains_media and self.gated_cross_attn_layer is not None: if self.vis_x is None: raise ValueError("vis_x must be conditioned before forward pass") @@ -56,7 +55,6 @@ def forward( lang_x, self.vis_x, media_locations=self.media_locations, - use_cached_media=self.use_cached_media, ) # Normal decoder layer @@ -66,7 +64,7 @@ def forward( return lang_x -class FlamingoLMMixin(nn.Module): +class CrossAttentionMixin(nn.Module): """ Mixin to add cross-attention layers to a language model. """ @@ -80,16 +78,15 @@ def _get_decoder_layers(self): def _set_decoder_layers(self, value): setattr_recursive(self, self.decoder_layers_attr_name, value) - def init_flamingo( + def init_cross_attention_layers( self, - media_token_id, lang_hidden_size, vis_hidden_size, cross_attn_every_n_layers, gradient_checkpointing, ): """ - Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. + Add gated cross attn layers to the decoder. 
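        Example (illustrative sketch only, not part of this patch): applying the mixin
        to a Hugging Face causal LM. The attribute path "transformer.blocks" and
        vis_hidden_size=1024 are example values for an MPT-style model with CLIP
        ViT-L/14 features, and set_decoder_layers_attr_name is assumed to be the
        helper retained from the renamed flamingo_lm.py module.

            extend_instance(lang_model, CrossAttentionMixin)
            lang_model.set_decoder_layers_attr_name("transformer.blocks")
            lang_model.init_cross_attention_layers(
                lang_hidden_size=lang_model.config.d_model,
                vis_hidden_size=1024,
                cross_attn_every_n_layers=4,
                gradient_checkpointing=False,
            )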
""" self.old_decoder_blocks = self._get_decoder_layers() self.gated_cross_attn_layers = nn.ModuleList( @@ -102,20 +99,10 @@ def init_flamingo( for layer_idx, _ in enumerate(self._get_decoder_layers()) ] ) - self.init_flamingo_layers(gradient_checkpointing) - self.media_token_id = media_token_id - self.initialized_flamingo = True - self._use_cached_vision_x = False - - def init_flamingo_layers(self, gradient_checkpointing): - """ - Re initializes the FlamingoLayers. - Propagates any changes made to self.gated_corss_attn_layers or self.old_decoder_blocks - """ self._set_decoder_layers( nn.ModuleList( [ - FlamingoLayer( + DecoderLayerWithCrossAttention( gated_cross_attn_layer, decoder_layer, gradient_checkpointing ) for gated_cross_attn_layer, decoder_layer in zip( @@ -124,37 +111,44 @@ def init_flamingo_layers(self, gradient_checkpointing): ] ) ) + self.initialized_cross_attention = True - def forward(self, input_ids, attention_mask, **kwargs): - """Condition the Flamingo layers on the media locations before forward()""" - if not self.initialized_flamingo: - raise ValueError( - "Flamingo layers are not initialized. Please call `init_flamingo` first." + def _condition_media_before_forward( + self, + input_ids: torch.Tensor, + vision_tokens: torch.Tensor = None, + past_media_locations: torch.Tensor = None, + past_vision_tokens: torch.Tensor = None, + ): + """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence""" + assert ( + self.initialized_cross_attention + ), "Cross attention layers have not been initialized. " + if past_media_locations is not None and past_vision_tokens is not None: + if vision_tokens is not None: + updated_vision_tokens = torch.cat( + [ + past_vision_tokens, + vision_tokens, + ], + dim=1, + ) + else: + updated_vision_tokens = past_vision_tokens + updated_media_locations = torch.cat( + [ + past_media_locations, + input_ids == self.media_token_id, + ], + dim=1, ) - - media_locations = input_ids == self.media_token_id - - # if there are media already cached and we're generating and there are no media tokens in the input, - # we'll assume that ALL input tokens should attend to the last previous media that is cached. - # this is especially important for HF generate() compatibility, since generate() calls forward() - # repeatedly one token at a time (with no media tokens). - # without this check, the model would not attend to any images when generating (after the first token) - use_cached_media_locations = ( - self._use_cached_vision_x - and self.is_conditioned() - and not media_locations.any() - ) + else: + updated_vision_tokens = vision_tokens + updated_media_locations = input_ids == self.media_token_id for layer in self._get_decoder_layers(): - if not use_cached_media_locations: - layer.condition_media_locations(media_locations) - layer.condition_use_cached_media(use_cached_media_locations) - - # package arguments for the other parent's forward. 
since we don't know the order of the arguments, - # make them all kwargs - kwargs["input_ids"] = input_ids - kwargs["attention_mask"] = attention_mask - return super().forward(**kwargs) # Call the other parent's forward method + layer.condition_vis_x(updated_vision_tokens) + layer.condition_media_locations(updated_media_locations) def is_conditioned(self) -> bool: """Check whether all decoder layers are already conditioned.""" @@ -164,4 +158,3 @@ def clear_conditioned_layers(self): for layer in self._get_decoder_layers(): layer.condition_vis_x(None) layer.condition_media_locations(None) - layer.condition_use_cached_media(None) diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 24d198bb..985e6482 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -4,23 +4,21 @@ import open_clip from .flamingo import Flamingo -from .flamingo_lm import FlamingoLMMixin +from .kosmos import Kosmos from .utils import extend_instance -from .mllm import MLLM def create_model_and_transforms( + model_family: str, clip_vision_encoder_path: str, clip_vision_encoder_pretrained: str, - lang_encoder_path: str, + lang_model_path: str, tokenizer_path: str, - cross_attn_every_n_layers: int = 1, use_local_files: bool = False, decoder_layers_attr_name: str = None, - freeze_lm_embeddings: bool = False, - model_family: str = "flamingo", - freeze_backbone_mllm: bool = False, - **flamingo_kwargs, + cache_dir: Optional[str] = None, + gradient_checkpointing: bool = False, + **model_kwargs, ): """ Initialize a Flamingo model from a pretrained vision encoder and language encoder. @@ -29,50 +27,50 @@ def create_model_and_transforms( Args: clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32") clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k") - lang_encoder_path (str): path to pretrained language encoder + lang_model_path (str): path to pretrained language encoder tokenizer_path (str): path to pretrained tokenizer cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1. use_local_files (bool, optional): whether to use local files. Defaults to False. decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None. - freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver. cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. + gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False. 
     Returns:
         Flamingo: Flamingo model from pretrained vision and language encoders
         Image processor: Pipeline to preprocess input images
         Tokenizer: A tokenizer for the language model
     """
+
+    assert model_family in ("flamingo", "kosmos")
+
+    # load vision encoder
     vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
         clip_vision_encoder_path,
         pretrained=clip_vision_encoder_pretrained,
         cache_dir=cache_dir,
     )
-    # set the vision encoder to output the visual features
     vision_encoder.visual.output_tokens = True
+    vision_encoder = vision_encoder.visual
+    # load tokenizer and ensure there is a pad token
     text_tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
         cache_dir=cache_dir,
     )
-    # add Flamingo special tokens to the tokenizer
-    text_tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
-    )
     if text_tokenizer.pad_token is None:
-        # Issue: GPT models don't have a pad token, which we use to
-        # modify labels for the loss.
         text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-    lang_encoder = AutoModelForCausalLM.from_pretrained(
-        lang_encoder_path,
+    # load language model
+    lang_model = AutoModelForCausalLM.from_pretrained(
+        lang_model_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
         cache_dir=cache_dir,
     )
-    # hacks for MPT-1B, which doesn't have a get_input_embeddings method
-    if "mpt-1b-redpajama-200b" in lang_encoder_path:
+    ## hacks for MPT-1B, which doesn't have a get_input_embeddings method
+    if "mpt-1b-redpajama-200b" in lang_model_path:
         class EmbeddingFnMixin:
             def get_input_embeddings(self):
@@ -81,69 +79,56 @@ def get_input_embeddings(self):
             def set_input_embeddings(self, new_embeddings):
                 self.transformer.wte = new_embeddings
-        extend_instance(lang_encoder, EmbeddingFnMixin)
+        extend_instance(lang_model, EmbeddingFnMixin)
+    # init the model
     if model_family == "flamingo":
-        # convert LM to FlamingoLM
-        extend_instance(lang_encoder, FlamingoLMMixin)
-
-        if decoder_layers_attr_name is None:
-            decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
-        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
-        lang_encoder.resize_token_embeddings(
-            len(text_tokenizer)
-        )
+        if decoder_layers_attr_name is None:
+            decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)
         model = Flamingo(
-            vision_encoder,
-            lang_encoder,
-            text_tokenizer.encode("<|endofchunk|>")[-1],
-            text_tokenizer.encode("<image>")[-1],
-            vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-                "width"
-            ],
-            cross_attn_every_n_layers=cross_attn_every_n_layers,
-            **flamingo_kwargs,
-        )
-
-        # Freeze all parameters
-        model.requires_grad_(False)
-        assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
-
-        # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
-        model.perceiver.requires_grad_(True)
-        model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
-
-        # TODO: FIX this. Currently we are just training all embeddings unless freeze_lm_embeddings is on in which case we only train <image> and <|endofchunk|> embeddings
-        model.lang_encoder.get_input_embeddings().requires_grad_(True)
-        model.lang_encoder.get_output_embeddings().requires_grad_(True)
-
-        print(
-            f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+            vision_encoder=vision_encoder,
+            lang_model=lang_model,
+            vis_feature_dim=open_clip.get_model_config(clip_vision_encoder_path)[
+                "vision_cfg"
+            ]["width"],
+            tokenizer_vocab_size=len(text_tokenizer),
+            gradient_checkpointing=gradient_checkpointing,
+            decoder_layers_attr_name=decoder_layers_attr_name,
+            pad_token=text_tokenizer.pad_token,
+            **model_kwargs,
         )
-    elif model_family == "mllm":
-        lang_encoder.resize_token_embeddings(len(text_tokenizer))
-
-        model = MLLM(
-            vision_model=vision_encoder,
-            language_model=lang_encoder,
-            padding_token_id=text_tokenizer.pad_token_id,
-            media_token_id=text_tokenizer.encode("<image>")[-1],
-            vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-                "width"
-            ],
+    elif model_family == "kosmos":
+        model = Kosmos(
+            vision_encoder=vision_encoder,
+            lang_model=lang_model,
+            vis_feature_dim=open_clip.get_model_config(clip_vision_encoder_path)[
+                "vision_cfg"
+            ]["width"],
+            tokenizer_vocab_size=len(text_tokenizer),
+            gradient_checkpointing=gradient_checkpointing,
+            pad_token=text_tokenizer.pad_token,
+            **model_kwargs,
        )
-        # Freeze vision encoder
-        model.vision_model.requires_grad_(False)
-        if freeze_backbone_mllm:
-            model.language_model.requires_grad_(False)
-
-        print(
-            f"MLLM model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
-        )
+    # add special tokens to the tokenizer and language models
+    text_tokenizer.add_special_tokens(
+        {"additional_special_tokens": list(model.special_tokens.values())}
+    )
+    model.lang_model.config.vocab_size = len(text_tokenizer)
+    model.set_special_token_ids(
+        {
+            v: text_tokenizer.convert_tokens_to_ids(v)
+            for v in model.special_tokens.values()
+        }
+    )
+    # freeze appropriate parameters
+    model.set_trainable()
+    print(
+        f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
+    )
     return model, image_processor, text_tokenizer
diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py
index 4acfabd2..196c3953 100644
--- a/open_flamingo/src/flamingo.py
+++ b/open_flamingo/src/flamingo.py
@@ -1,337 +1,57 @@
-import torch
-from einops import rearrange
 from torch import nn
 from .helpers import PerceiverResampler
-from torch.distributed.fsdp.wrap import (
-    enable_wrap,
-    wrap,
-)
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-)
+from .vlm import VLMWithCrossAttention
-from .utils import apply_with_stopping_condition
-
-class Flamingo(nn.Module):
+class Flamingo(VLMWithCrossAttention):
     def __init__(
         self,
         vision_encoder: nn.Module,
-        lang_encoder: nn.Module,
-        eoc_token_id: int,
-        media_token_id: int,
-        vis_dim: int,
+        lang_model: nn.Module,
+        vis_feature_dim: int,
+        tokenizer_vocab_size: int,
+        pad_token: str,
         cross_attn_every_n_layers: int = 1,
+        decoder_layers_attr_name: str = None,
         gradient_checkpointing: bool = False,
     ):
         """
         Args:
             vision_encoder (nn.Module): HF CLIPModel
-            lang_encoder (nn.Module): HF causal language model
-            eoc_token_id (int): Token id for <|endofchunk|>
-            media_token_id (int): Token id for <image>
-            vis_dim (int): Dimension of the visual features.
-                Visual features are projected to match this shape along the last dimension.
+            lang_model (nn.Module): HF causal language model
+            vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
+            tokenizer_vocab_size (int): size of the tokenizer vocab
             cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
-        """
-        super().__init__()
-        self.eoc_token_id = eoc_token_id
-        self.media_token_id = media_token_id
-        self.vis_dim = vis_dim
-        if hasattr(lang_encoder.config, "d_model"):
-            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
-        else:
-            self.lang_dim = lang_encoder.config.hidden_size
-
-        self.vision_encoder = vision_encoder.visual
-        self.perceiver = PerceiverResampler(dim=self.vis_dim)
-        self.lang_encoder = lang_encoder
-        self.lang_encoder.init_flamingo(
-            media_token_id=media_token_id,
-            lang_hidden_size=self.lang_dim,
-            vis_hidden_size=self.vis_dim,
-            cross_attn_every_n_layers=cross_attn_every_n_layers,
+            decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+        """
+        self._special_tokens = {
+            "eoc_token": "<|endofchunk|>",
+            "media_token": "<image>",
+            "pad_token": pad_token,
+        }
+        super().__init__(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=PerceiverResampler(dim=vis_feature_dim),
+            lang_model=lang_model,
             gradient_checkpointing=gradient_checkpointing,
+            tokenizer_vocab_size=tokenizer_vocab_size,
+            cross_attn_every_n_layers=cross_attn_every_n_layers,
+            decoder_layers_attr_name=decoder_layers_attr_name,
         )
-        self._use_gradient_checkpointing = gradient_checkpointing
-        self.perceiver._use_gradient_checkpointing = gradient_checkpointing
-    def forward(
-        self,
-        vision_x: torch.Tensor,
-        lang_x: torch.Tensor,
-        attention_mask: torch.Tensor = None,
-        labels: torch.Tensor = None,
-        clear_conditioned_layers: bool = True,
-        past_key_values=None,
-        use_cache: bool = False,
-    ):
+    def set_trainable(self):
         """
-        Forward pass of Flamingo.
-
-        Args:
-            vision_x (torch.Tensor): Vision input
-                shape (B, T_img, F, C, H, W) with F=1
-            lang_x (torch.Tensor): Language input ids
-                shape (B, T_txt)
-            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-            labels (torch.Tensor, optional): Labels. Defaults to None.
-            clear_conditioned_layers: if True, clear the conditioned layers
-                once the foward pass is completed. Set this to false if the
-                same set of images will be reused in another subsequent
-                forward pass.
-            past_key_values: pre-computed values to pass to language model.
-                See past_key_values documentation in Hugging Face
-                CausalLM models.
-            use_cache: whether to use cached key values. See use_cache
-                documentation in Hugging Face CausalLM models.
+        Freeze everything except: perceiver, gated_cross_attn_layers, and inserted LM input embeddings
         """
-        assert (
-            self.lang_encoder.initialized_flamingo
-        ), "Flamingo layers are not initialized. Please call `init_flamingo` first."
-
-        assert (
-            self.lang_encoder._use_cached_vision_x or vision_x is not None
-        ), "Must provide either vision_x or have precached media using cache_media()."
-
-        if self.lang_encoder._use_cached_vision_x:
-            # Case: use cached; vision_x should be cached and other
-            # vision-related inputs should not be provided.
-            assert (
-                vision_x is None
-            ), "Expect vision_x to be None when media has been cached using cache_media().
Try uncache_media() first." - assert self.lang_encoder.is_conditioned() - - else: - # Case: do not use caching (i.e. this is a standard forward pass); - self._encode_vision_x(vision_x=vision_x) - self._condition_media_locations(input_ids=lang_x) - - output = self.lang_encoder( - input_ids=lang_x, - attention_mask=attention_mask, - labels=labels, - past_key_values=past_key_values, - use_cache=use_cache, + self.requires_grad_(False) + self.vision_tokenizer.requires_grad_(True) + self.lang_model.gated_cross_attn_layers.requires_grad_(True) + self.lang_model.get_output_embeddings().set_requires_grad( + require_regular_grad=False, + require_additional_grad=True, ) - - if clear_conditioned_layers: - self.lang_encoder.clear_conditioned_layers() - - return output - - def generate( - self, - vision_x: torch.Tensor, - lang_x: torch.Tensor, - attention_mask: torch.Tensor = None, - **kwargs, - ): - """ - Generate text conditioned on vision and language inputs. - - Args: - vision_x (torch.Tensor): Vision input - shape (B, T_img, F, C, H, W) - images in the same chunk are collated along T_img, and frames are collated along F - currently only F=1 is supported (single-frame videos) - lang_x (torch.Tensor): Language input - shape (B, T_txt) - **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs: - max_length (int, optional): Maximum length of the output. Defaults to None. - attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. - num_beams (int, optional): Number of beams. Defaults to 1. - max_new_tokens (int, optional): Maximum new tokens. Defaults to None. - temperature (float, optional): Temperature. Defaults to 1.0. - top_k (int, optional): Top k. Defaults to 50. - top_p (float, optional): Top p. Defaults to 1.0. - no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. - length_penalty (float, optional): Length penalty. Defaults to 1.0. - num_return_sequences (int, optional): Number of return sequences. Defaults to 1. - do_sample (bool, optional): Do sample. Defaults to False. - early_stopping (bool, optional): Early stopping. Defaults to False. - Returns: - torch.Tensor: lang_x with generated tokens appended to it - """ - num_beams = kwargs.pop("num_beams", 1) - if num_beams > 1: - vision_x = vision_x.repeat_interleave(num_beams, dim=0) - - self.lang_encoder._use_cached_vision_x = True - self._encode_vision_x(vision_x=vision_x) - - eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id) - output = self.lang_encoder.generate( - input_ids=lang_x, - attention_mask=attention_mask, - eos_token_id=eos_token_id, - num_beams=num_beams, - **kwargs, + self.lang_model.get_input_embeddings().set_requires_grad( + require_regular_grad=False, + require_additional_grad=True, ) - - self.lang_encoder.clear_conditioned_layers() - self.lang_encoder._use_cached_vision_x = False - return output - - def _encode_vision_x(self, vision_x: torch.Tensor): - """ - Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 
- Args: - vision_x (torch.Tensor): Vision input - shape (B, T_img, F, C, H, W) - Images in the same chunk are collated along T_img, and frames are collated along F - Currently only F=1 is supported (single-frame videos) - - rearrange code based on https://github.com/dhansmair/flamingo-mini - """ - - assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" - b, T, F = vision_x.shape[:3] - - vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") - with torch.no_grad(): - vision_x = self.vision_encoder(vision_x)[1] - vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) - vision_x = self.perceiver(vision_x) - - for layer in self.lang_encoder._get_decoder_layers(): - layer.condition_vis_x(vision_x) - - def wrap_fsdp(self, wrapper_kwargs, device_id): - """ - Manually wraps submodules for FSDP and move other parameters to device_id. - - Why manually wrap? - - all parameters within the FSDP wrapper must have the same requires_grad. - We have a mix of frozen and unfrozen parameters. - - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors - See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 - - The rough wrapping structure is: - - FlamingoModel - - FSDP(FSDP(vision_encoder)) - - FSDP(FSDP(perceiver)) - - lang_encoder - - FSDP(FSDP(input_embeddings)) - - FlamingoLayers - - FSDP(FSDP(gated_cross_attn_layer)) - - FSDP(FSDP(decoder_layer)) - - FSDP(FSDP(output_embeddings)) - - other parameters - - Known issues: - - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, - train with DDP or set the --freeze_lm_embeddings flag to true. - - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. - Although the training curves look okay, we found that downstream performance dramatically - degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). - - FAQs about our FSDP wrapping strategy: - Why double wrap? - As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook - only free gathered parameters if the module is NOT FSDP root. - - Why unfreeze the decoder_layers? - See https://github.com/pytorch/pytorch/issues/95805 - As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param - requires_grad=True. We need the postback to fire to avoid OOM. - To effectively freeze the decoder layers, we exclude them from the optimizer. - - What is assumed to be frozen v. unfrozen? 
- We assume that the model is being trained under normal Flamingo settings - with these lines being called in factory.py: - ``` - # Freeze all parameters - model.requires_grad_(False) - assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 - - # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings - model.perceiver.requires_grad_(True) - model.lang_encoder.gated_cross_attn_layers.requires_grad_(True) - [optional] model.lang_encoder.get_input_embeddings().requires_grad_(True) - ``` - """ - # unfreeze the decoder layers - for block in self.lang_encoder.old_decoder_blocks: - block.requires_grad_(True) - - # wrap in FSDP - with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): - self.perceiver = wrap(wrap(self.perceiver)) - self.lang_encoder.old_decoder_blocks = nn.ModuleList( - wrap(wrap(block)) for block in self.lang_encoder.old_decoder_blocks - ) - self.lang_encoder.gated_cross_attn_layers = nn.ModuleList( - wrap(wrap(layer)) if layer is not None else None - for layer in self.lang_encoder.gated_cross_attn_layers - ) - self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) - self.lang_encoder.set_input_embeddings( - wrap(wrap(self.lang_encoder.get_input_embeddings())) - ) - self.lang_encoder.set_output_embeddings( - wrap(wrap(self.lang_encoder.get_output_embeddings())) - ) - self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen - - # manually move non-FSDP managed parameters to device_id - # these are all in lang_encoder - apply_with_stopping_condition( - module=self.lang_encoder, - apply_fn=lambda m: m.to(device_id), - apply_condition=lambda m: len(list(m.children())) == 0, - stopping_condition=lambda m: isinstance(m, FSDP), - ) - - # exclude the original decoder layers from the optimizer - for block in self.lang_encoder.old_decoder_blocks: - for p in block.parameters(): - p.exclude_from_optimizer = True - - # set up clip_grad_norm_ function - def clip_grad_norm_(max_norm): - self.perceiver.clip_grad_norm_(max_norm) - for layer in self.lang_encoder.gated_cross_attn_layers: - if layer is not None: - layer.clip_grad_norm_(max_norm) - self.lang_encoder.get_input_embeddings().clip_grad_norm_(max_norm) - - self.clip_grad_norm_ = clip_grad_norm_ - - def _condition_media_locations(self, input_ids: torch.Tensor): - """ - Compute the media token locations from lang_x and condition the language model on these. - Args: - input_ids (torch.Tensor): Language input - shape (B, T_txt) - """ - media_locations = input_ids == self.media_token_id - - for layer in self.lang_encoder._get_decoder_layers(): - layer.condition_media_locations(media_locations) - - def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor): - """ - Pre-cache a prompt/sequence of images / text for log-likelihood evaluations. - All subsequent calls to forward() will generate attending to the LAST - image in vision_x. - This is not meant to be used to cache things for generate(). - Args: - input_ids (torch.Tensor): Language input - shape (B, T_txt) - vision_x (torch.Tensor): Vision input - shape (B, T_img, F, C, H, W) - Images in the same chunk are collated along T_img, and frames are collated along F - Currently only F=1 is supported (single-frame videos) - """ - self._encode_vision_x(vision_x=vision_x) - self._condition_media_locations(input_ids=input_ids) - self.lang_encoder._use_cached_vision_x = True - - def uncache_media(self): - """ - Clear all conditioning. 
- """ - self.lang_encoder.clear_conditioned_layers() - self.lang_encoder._use_cached_vision_x = False diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index 239503f8..fb809cf8 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -3,9 +3,25 @@ """ import torch +import torch.nn.functional as F from einops import rearrange, repeat from einops_exts import rearrange_many from torch import einsum, nn +from transformers.modeling_outputs import CausalLMOutputWithPast +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class VLMOutputWithPast(CausalLMOutputWithPast): + """ + VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes: + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + """ + + past_media_locations: Optional[torch.Tensor] = None + past_vision_tokens: Optional[torch.Tensor] = None def exists(val): @@ -70,6 +86,7 @@ def __init__( self, *, dim, + dim_inner=None, depth=6, dim_head=64, heads=8, @@ -78,15 +95,39 @@ def __init__( max_num_frames=None, ff_mult=4, ): + """ + Perceiver module which takes in image features and outputs image tokens. + Args: + dim (int): final dimension of the incoming image features + dim_inner (int, optional): final dimension to project the incoming image features to; + also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim. + depth (int, optional): number of layers. Defaults to 6. + dim_head (int, optional): dimension of each head. Defaults to 64. + heads (int, optional): number of heads. Defaults to 8. + num_latents (int, optional): number of latent tokens to use in the Perceiver; + also corresponds to number of tokens per sequence to output. Defaults to 64. + max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver + and keep positional embeddings for. If None, no positional embeddings are used. + max_num_frames (int, optional): maximum number of frames to input into the Perceiver + and keep positional embeddings for. If None, no positional embeddings are used. + ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4. 
+ """ super().__init__() - self.latents = nn.Parameter(torch.randn(num_latents, dim)) + if dim_inner is not None: + self.projection = nn.Linear(dim, dim_inner) + else: + self.projection = None + dim_inner = dim + + self.latents = nn.Parameter(torch.randn(num_latents, dim_inner)) + # positional embeddings self.frame_embs = ( - nn.Parameter(torch.randn(max_num_frames, dim)) + nn.Parameter(torch.randn(max_num_frames, dim_inner)) if exists(max_num_frames) else None ) self.media_time_embs = ( - nn.Parameter(torch.randn(max_num_media, 1, dim)) + nn.Parameter(torch.randn(max_num_media, 1, dim_inner)) if exists(max_num_media) else None ) @@ -96,13 +137,18 @@ def __init__( self.layers.append( nn.ModuleList( [ - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - FeedForward(dim=dim, mult=ff_mult), + PerceiverAttention( + dim=dim_inner, dim_head=dim_head, heads=heads + ), + FeedForward(dim=dim_inner, mult=ff_mult), ] ) ) - self.norm = nn.LayerNorm(dim) + self.norm = nn.LayerNorm(dim_inner) + + self.num_tokens_per_media = num_latents + self.dim_media = dim_inner def forward(self, x): """ @@ -114,6 +160,9 @@ def forward(self, x): """ b, T, F, v = x.shape[:4] + if exists(self.projection): + x = self.projection(x) + # frame and media time embeddings if exists(self.frame_embs): frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) @@ -132,6 +181,18 @@ def forward(self, x): return self.norm(latents) +class LinearProjection(nn.Module): + """Linear projection from patch features to image tokens.""" + + def __init__(self, *, dim, dim_out): + super().__init__() + self.proj = nn.Linear(dim, dim_out) + self.out_dim = dim_out + + def forward(self, x): + return self.proj(x) + + # gated cross attention class MaskedCrossAttention(nn.Module): def __init__( @@ -157,7 +218,7 @@ def __init__( # whether for text to only attend to immediate preceding image, or all previous images self.only_attend_immediate_media = only_attend_immediate_media - def forward(self, x, media, media_locations=None, use_cached_media=False): + def forward(self, x, media, media_locations=None): """ Args: x (torch.Tensor): text features @@ -165,19 +226,16 @@ def forward(self, x, media, media_locations=None, use_cached_media=False): media (torch.Tensor): image features shape (B, T_img, n, D_img) where n is the dim of the latents media_locations: boolean mask identifying the media tokens in x - shape (B, T_txt) - use_cached_media: bool - If true, treat all of x as if they occur after the last media - registered in media_locations. 
T_txt does not need to exactly - equal media_locations.shape[1] in this case + shape (B, T_txt_all) + T_txt_all >= T_txt + If T_txt_all > T_txt, then the last T_txt text_times are used """ - if not use_cached_media: - assert ( - media_locations.shape[1] == x.shape[1] - ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}" - T_txt = x.shape[1] + assert ( + T_txt <= media_locations.shape[1] + ), "current text cannot be longer than conditioned media locations" + _, T_img, n = media.shape[:3] h = self.heads @@ -196,16 +254,8 @@ def forward(self, x, media, media_locations=None, use_cached_media=False): if exists(media_locations): media_time = torch.arange(T_img, device=x.device) + 1 - if use_cached_media: - # text time is set to the last cached media location - text_time = repeat( - torch.count_nonzero(media_locations, dim=1), - "b -> b i", - i=T_txt, - ) - else: - # at each boolean of True, increment the time counter (relative to media time) - text_time = media_locations.cumsum(dim=-1) + # at each boolean of True, increment the time counter (relative to media time) + text_time = media_locations.cumsum(dim=-1)[:, -T_txt:] # text time must equal media time if only attending to most immediate image # otherwise, as long as text time is greater than media time (if attending to all previous images / media) @@ -262,14 +312,12 @@ def forward( x, media, media_locations=None, - use_cached_media=False, ): x = ( self.attn( x, media, media_locations=media_locations, - use_cached_media=use_cached_media, ) * self.attn_gate.tanh() + x @@ -277,3 +325,196 @@ def forward( x = self.ff(x) * self.ff_gate.tanh() + x return x + + +# Both DecoupledEmbedding and DecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity +class DecoupledEmbedding(nn.Embedding): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0, + then it will create `num_additional_embeddings` additional parameters that are always trained. If + `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`. + """ + + def __init__( + self, + num_embeddings, + num_additional_embeddings, + embedding_dim, + partially_freeze=True, + device=None, + dtype=None, + padding_idx=None, + **kwargs, + ) -> None: + """ + Args: + num_embeddings (`int`): + Size of the dictionary of embeddings + num_additional_embeddings (`int`): + Number of additional embeddings. Only useful when you `partially_freeze=True`. + embedding_dim (`int`): + The size of each embedding vector + partially_freeze: (`bool`, *optional*, defaults to `True`): + If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen. + padding_idx (`int`, *optional*): + The padding index (needs to be less than num_embeddings) + + Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`, + `max_norm` or `norm_type`. We are not supporting these. + """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError( + f"padding_idx must be within num_embeddings. 
Got {padding_idx} and {num_embeddings}" + ) + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + self.set_requires_grad( + require_regular_grad=not partially_freeze, require_additional_grad=True + ) + + def set_requires_grad(self, require_regular_grad, require_additional_grad): + """ + Helper function to separately set the requires_grad flag for the regular weight and the additional weight. + """ + self.weight.requires_grad_(require_regular_grad) + self.additional_embedding.requires_grad_(require_additional_grad) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return F.embedding(input_ids, self.weight) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding( + input_ids_additional_vocab - self.num_embeddings + ) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + (not self.weight.requires_grad), + ) + + +class DecoupledLinear(nn.Linear): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. 
`partially_freeze=True`), and if `out_additional_features` > 0,
+    then it will create `out_additional_features * in_features` additional parameters that are always trained. If
+    `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        out_additional_features: int = 0,
+        bias: bool = True,
+        partially_freeze: bool = True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        """
+        out_additional_features: int. Number of additional trainable dimensions. Only makes sense when
+        `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra
+        parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
+        """
+        super().__init__(in_features, out_features, bias, device, dtype)
+        self.out_additional_features = out_additional_features
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.has_bias = bias
+        if out_additional_features > 0:
+            self.additional_fc = nn.Linear(
+                in_features=in_features,
+                out_features=out_additional_features,
+                bias=self.has_bias,
+                device=device,
+                dtype=dtype,
+            )
+        self.set_requires_grad(
+            require_regular_grad=not partially_freeze, require_additional_grad=True
+        )
+
+    def set_requires_grad(self, require_regular_grad, require_additional_grad):
+        """
+        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
+        """
+        self.weight.requires_grad_(require_regular_grad)
+        if self.has_bias:
+            self.bias.requires_grad_(require_regular_grad)
+        self.additional_fc.requires_grad_(require_additional_grad)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = F.linear(input, self.weight, self.bias)
+
+        if self.out_additional_features > 0:
+            additional_features = F.linear(
+                input, self.additional_fc.weight, self.additional_fc.bias
+            )
+            output = torch.cat((output, additional_features), -1)
+
+        return output
+
+    def extra_repr(self) -> str:
+        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
+        return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
+            self.in_features,
+            self.out_features,
+            self.out_additional_features,
+            self.bias is not None,
+            (not self.weight.requires_grad or not self.bias.requires_grad),
+        )
diff --git a/open_flamingo/src/kosmos.py b/open_flamingo/src/kosmos.py
new file mode 100644
index 00000000..81ad44f8
--- /dev/null
+++ b/open_flamingo/src/kosmos.py
@@ -0,0 +1,46 @@
+from torch import nn
+from .helpers import PerceiverResampler
+from .vlm import VLMWithLanguageStream
+
+
+class Kosmos(VLMWithLanguageStream):
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        lang_model: nn.Module,
+        vis_feature_dim: int,
+        tokenizer_vocab_size: int,
+        pad_token: str,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        Args:
+            vision_encoder (nn.Module): HF CLIPModel
+            lang_model (nn.Module): HF causal language model
+            vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
+            tokenizer_vocab_size (int): size of the tokenizer vocab
+            pad_token (str): the padding token. None if no padding token; then a padding token
+                will be inserted into self.special_tokens, which factory.py fills after creating new tokens
+            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
+        """
+        self._special_tokens = {
+            "media_token": "<image>",
+            "pad_token": pad_token,
+        }
+        lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+        super().__init__(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=PerceiverResampler(
+                dim=vis_feature_dim, dim_inner=lang_embedding_dim
+            ),
+            lang_model=lang_model,
+            tokenizer_vocab_size=tokenizer_vocab_size,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+
+    def set_trainable(self):
+        """
+        Unfreeze everything except the vision_encoder
+        """
+        self.requires_grad_(True)
+        self.vision_encoder.requires_grad_(False)
diff --git a/open_flamingo/src/mllm.py b/open_flamingo/src/mllm.py
deleted file mode 100644
index 5492e126..00000000
--- a/open_flamingo/src/mllm.py
+++ /dev/null
@@ -1,372 +0,0 @@
-import torch
-from einops import rearrange
-from torch import nn
-from typing import Optional
-from torch.nn import CrossEntropyLoss
-from .helpers import PerceiverResampler
-from torch.distributed.fsdp.wrap import (
-    enable_wrap,
-    wrap,
-)
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-)
-from .utils import apply_with_stopping_condition
-
-def torch_stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
-    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
-    padded_tensors = []
-    for tensor in list_of_tensors:
-        num_tokens = tensor.size(0)
-        if len(tensor.size()) == 1:
-            padding = torch.full(
-                (max_tokens - num_tokens,),
-                padding_value,
-                dtype=tensor.dtype,
-                device=tensor.device,
-            )
-        else:
-            padding = torch.full(
-                (max_tokens - num_tokens, tensor.size(1)),
-                padding_value,
-                dtype=tensor.dtype,
-                device=tensor.device,
-            )
-        padded_tensor = (
-            torch.cat((tensor, padding), dim=0)
-            if padding_side == "right"
-            else torch.cat((padding, tensor), dim=0)
-        )
-        padded_tensors.append(padded_tensor)
-    return torch.stack(padded_tensors)
-
-
-class MLLM(nn.Module):
-    def __init__(
-        self, language_model, vision_model, vis_dim, media_token_id, padding_token_id
-    ):
-        super().__init__()
-        self.language_model = language_model
-        self.vision_model = vision_model.visual
-        self.vis_dim = vis_dim
-        self.media_token_id = media_token_id
-        self.padding_token_id = padding_token_id
-        self.perceiver = PerceiverResampler(dim=self.vis_dim)
-        self.language_projection = nn.Linear(
-            self.vis_dim, self.language_model.config.hidden_size
-        )
-
-    def _encode_vision_x(self, vision_x: torch.Tensor):
-        """
-        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
- Args: - vision_x (torch.Tensor): Vision input - shape (B, T_img, F, C, H, W) - Images in the same chunk are collated along T_img, and frames are collated along F - Currently only F=1 is supported (single-frame videos) - - rearrange code based on https://github.com/dhansmair/flamingo-mini - """ - - assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" - b, T, F = vision_x.shape[:3] - - vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") - with torch.no_grad(): - vision_x = self.vision_model(vision_x)[1] - vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) - vision_x = self.perceiver(vision_x) - language_model_inputs = self.language_projection(vision_x) - - return language_model_inputs - - def forward( - self, - vision_x: torch.FloatTensor, - lang_x: torch.FloatTensor, - attention_mask: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - labels: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - ): - return_dict = ( - return_dict if return_dict is not None else False - ) - - vision_x = self._encode_vision_x(vision_x) - - lang_embeds = self.language_model.get_input_embeddings()(lang_x) - - if attention_mask is None: - attention_mask = torch.ones_like(lang_x) - - labels = lang_x if labels is None else labels - - if vision_x is not None: - multimodal_embeds = [] - multimodal_labels = [] - multimodal_attention_mask = [] - - for i in range(lang_embeds.shape[0]): - # get index of tokens in lang_x[i] - image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0] - # since an image is represented by 64 tokens, we need to offset the image_token_idxs by 64 except for the first image - for j, img_idx in enumerate(image_token_idxs): - image_token_idxs[j] += 63 * j - - new_embed = lang_embeds[i].clone() - new_attention_mask = ( - attention_mask[i].clone() if attention_mask is not None else None - ) - new_label = labels[i].clone() - - for img_num, img_idx in enumerate(image_token_idxs): - new_embed = torch.cat( - ( - new_embed[:img_idx], - vision_x[i][img_num], - new_embed[img_idx + 1 :], - ), - dim=0, - ) - - new_attention_mask = torch.cat( - ( - new_attention_mask[:img_idx], - torch.ones(64, dtype=torch.long).to(attention_mask.device), - new_attention_mask[img_idx + 1 :], - ), - dim=0, - ) - - new_label = torch.cat( - ( - new_label[:img_idx], - torch.ones(64, dtype=torch.long).to(labels.device) * -100, - new_label[img_idx + 1 :], - ), - dim=0, - ) - - multimodal_embeds.append(new_embed) - multimodal_attention_mask.append(new_attention_mask) - multimodal_labels.append(new_label) - - multimodal_embeds = torch_stack_with_padding( - multimodal_embeds, padding_value=self.padding_token_id - ) - multimodal_attention_mask = torch_stack_with_padding( - multimodal_attention_mask, padding_value=0 - ) - multimodal_labels = torch_stack_with_padding( - multimodal_labels, padding_value=-100 - ) - else: - multimodal_embeds = lang_embeds - multimodal_attention_mask = attention_mask - multimodal_labels = labels - - outputs = self.language_model( - inputs_embeds=multimodal_embeds, - attention_mask=multimodal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - logits = outputs.logits if return_dict else outputs[0] - loss = None - # we compute the loss here since we need to take into account the sequence length of the query embeds - if multimodal_labels is not None: - 
multimodal_labels = multimodal_labels.to(logits.device) - logits = logits[:, -multimodal_labels.size(1) :, :] - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = multimodal_labels[..., 1:].contiguous().to(logits.device) - - # Flatten the tokens - loss_fct = CrossEntropyLoss(reduction="mean") - - loss = loss_fct( - shift_logits.view(-1, self.language_model.config.vocab_size), - shift_labels.view(-1), - ) - - return (loss, logits) if loss is not None else (logits,) - - @torch.no_grad() - def generate( - self, - vision_x: torch.FloatTensor, - lang_x: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, - **generate_kwargs, - ) -> torch.LongTensor: - batch_size = vision_x.shape[0] - - if attention_mask is None: - attention_mask = torch.ones(lang_x.shape, dtype=torch.long).to( - lang_x.device - ) - - vision_x = self._encode_vision_x(vision_x) - - lang_embeds = self.language_model.get_input_embeddings()(lang_x) - - if vision_x is not None: - multimodal_embeds = [] - multimodal_attention_mask = [] - - for i in range(lang_embeds.shape[0]): - # get index of tokens in lang_x[i] - image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0] - # since an image is represented by 64 tokens, we need to offset the image_token_idxs by 64 except for the first image - for j, img_idx in enumerate(image_token_idxs): - image_token_idxs[j] += 63 * j - - new_embed = lang_embeds[i].clone() - new_attention_mask = ( - attention_mask[i].clone() if attention_mask is not None else None - ) - - for img_num, img_idx in enumerate(image_token_idxs): - new_embed = torch.cat( - ( - new_embed[:img_idx], - vision_x[i][img_num], - new_embed[img_idx + 1 :], - ), - dim=0, - ) - - new_attention_mask = torch.cat( - ( - new_attention_mask[:img_idx], - torch.ones(64, dtype=torch.long).to(attention_mask.device), - new_attention_mask[img_idx + 1 :], - ), - dim=0, - ) - - multimodal_embeds.append(new_embed) - multimodal_attention_mask.append(new_attention_mask) - - multimodal_embeds = torch_stack_with_padding( - multimodal_embeds, - padding_value=self.padding_token_id, - padding_side="left", - ) - multimodal_attention_mask = torch_stack_with_padding( - multimodal_attention_mask, padding_value=0, padding_side="left" - ) - else: - multimodal_embeds = lang_embeds - multimodal_attention_mask = attention_mask - - outputs = self.language_model.generate( - input_ids=None, - inputs_embeds=multimodal_embeds, - attention_mask=multimodal_attention_mask, - **generate_kwargs, - ) - - return outputs - - def wrap_fsdp(self, wrapper_kwargs, device_id, lm_trainable=False): - """ - Manually wraps submodules for FSDP and move other parameters to device_id. - - Why manually wrap? - - all parameters within the FSDP wrapper must have the same requires_grad. - We have a mix of frozen and unfrozen parameters. - - model.vision_model.visual needs to be individually wrapped or encode_vision_x errors - See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 - - The rough wrapping structure is: - - FlamingoModel - - FSDP(FSDP(vision_model)) - - FSDP(FSDP(perceiver)) - - language_model - - FSDP(FSDP(input_embeddings)) - - FlamingoLayers - - FSDP(FSDP(gated_cross_attn_layer)) - - FSDP(FSDP(decoder_layer)) - - FSDP(FSDP(output_embeddings)) - - other parameters - - Known issues: - - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, - train with DDP or set the --freeze_lm_embeddings flag to true. 
- - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. - Although the training curves look okay, we found that downstream performance dramatically - degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). - - FAQs about our FSDP wrapping strategy: - Why double wrap? - As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook - only free gathered parameters if the module is NOT FSDP root. - - Why unfreeze the decoder_layers? - See https://github.com/pytorch/pytorch/issues/95805 - As of torch==2.0.1, FSDP's _post_backward_hook is only registed if the flat param - requires_grad=True. We need the postback to fire to avoid OOM. - To effectively freeze the decoder layers, we exclude them from the optimizer. - - What is assumed to be frozen v. unfrozen? - We assume that the model is being trained under normal Flamingo settings - with these lines being called in factory.py: - ``` - # Freeze all parameters - model.requires_grad_(False) - assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0 - - # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings - model.perceiver.requires_grad_(True) - model.language_model.gated_cross_attn_layers.requires_grad_(True) - [optional] model.language_model.get_input_embeddings().requires_grad_(True) - ``` - """ - # unfreeze the decoder layers - for p in self.language_model.parameters(): - p.requires_grad_(True) - - # wrap in FSDP - with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): - self.language_projection = wrap(wrap(self.language_projection)) - self.perceiver = wrap(wrap(self.perceiver)) - - self.language_model.model.layers = nn.ModuleList( - wrap(wrap(block)) for block in self.language_model.model.layers - ) - self.language_model.set_input_embeddings( - wrap(wrap(self.language_model.get_input_embeddings())) - ) - self.language_model.set_output_embeddings( - wrap(wrap(self.language_model.get_output_embeddings())) - ) - self.vision_model = wrap(wrap(self.vision_model)) # frozen - - # manually move non-FSDP managed parameters to device_id - # these are all in language_model - apply_with_stopping_condition( - module=self.language_model, - apply_fn=lambda m: m.to(device_id), - apply_condition=lambda m: len(list(m.children())) == 0, - stopping_condition=lambda m: isinstance(m, FSDP), - ) - - if not lm_trainable: - # exclude the original decoder layers from the optimizer - for p in self.language_model.parameters(): - p.exclude_from_optimizer = True - - # set up clip_grad_norm_ function - def clip_grad_norm_(max_norm): - self.perceiver.clip_grad_norm_(max_norm) - self.language_projection.clip_grad_norm_(max_norm) - # TODO: clip the decoder layers if they are unfrozen - if lm_trainable: - self.language_model.parameters().clip_grad_norm_(max_norm) - - self.clip_grad_norm_ = clip_grad_norm_ diff --git a/open_flamingo/src/utils.py b/open_flamingo/src/utils.py index 78952646..a77437f3 100644 --- a/open_flamingo/src/utils.py +++ b/open_flamingo/src/utils.py @@ -1,3 +1,6 @@ +import torch + + def extend_instance(obj, mixin): """Apply mixins to a class instance after creation""" base_cls = obj.__class__ @@ -46,3 +49,40 @@ def apply_with_stopping_condition( stopping_condition=stopping_condition, **other_args ) + + +def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"): + """ + Stack a list of tensors with padding on one side + Args: + list_of_tensors (list[torch.Tensor]): List of tensors to stack + padding_value (int, optional): 
Value to pad with. Defaults to 0. + padding_side (str, optional): Side to pad on. Defaults to "right". + Returns: + torch.Tensor: Stacked tensors + """ + max_tokens = max(tensor.size(0) for tensor in list_of_tensors) + padded_tensors = [] + for tensor in list_of_tensors: + num_tokens = tensor.size(0) + if len(tensor.size()) == 1: + padding = torch.full( + (max_tokens - num_tokens,), + padding_value, + dtype=tensor.dtype, + device=tensor.device, + ) + else: + padding = torch.full( + (max_tokens - num_tokens, tensor.size(1)), + padding_value, + dtype=tensor.dtype, + device=tensor.device, + ) + padded_tensor = ( + torch.cat((tensor, padding), dim=0) + if padding_side == "right" + else torch.cat((padding, tensor), dim=0) + ) + padded_tensors.append(padded_tensor) + return torch.stack(padded_tensors) diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py new file mode 100644 index 00000000..6142e530 --- /dev/null +++ b/open_flamingo/src/vlm.py @@ -0,0 +1,595 @@ +import torch +from einops import rearrange +from torch import nn +from typing import List, Optional, Tuple, Union +from .utils import extend_instance, stack_with_padding +from .cross_attn_lm import CrossAttentionMixin +from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast +from transformers.modeling_outputs import CausalLMOutputWithPast + + +class VLM(nn.Module): + """ + Generic vision-language model (VLM) class. + A VLM consists of four components: + 1. A vision encoder that extracts features from pixels, e.g. CLIP + input: (B, T_img, F, C, H, W) + output: (B, T_img, F, v, d) + 2. An image tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head + input: (B, T_img, F, v, d) + output: (B, T_img, n, d) + 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence + 4. A language model + """ + + def __init__( + self, + vision_encoder: nn.Module, + vision_tokenizer: nn.Module, + lang_model: nn.Module, + tokenizer_vocab_size: int, + gradient_checkpointing: bool = False, + ): + """ + Args: + vision_encoder (nn.Module): e.g. CLIP + vision_tokenizer (nn.Module): e.g. PerceiverResampler + lang_model (nn.Module): e.g. MPT + gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False. + cross_attn_every_n_layers (int, optional): If using cross-attention, perform cross-attention every n layers. Defaults to None. 
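        Example (illustrative sketch; the checkpoint paths below are placeholders, and
        cross_attn_every_n_layers is forwarded to the Flamingo subclass through
        **model_kwargs). A concrete VLM is normally built through the factory rather
        than instantiated directly:

            model, image_processor, tokenizer = create_model_and_transforms(
                model_family="flamingo",
                clip_vision_encoder_path="ViT-L-14",
                clip_vision_encoder_pretrained="openai",
                lang_model_path="path/to/hf-causal-lm",
                tokenizer_path="path/to/hf-causal-lm",
                cross_attn_every_n_layers=4,
            )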
+ """ + super().__init__() + + # save dimension information + self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1] + if hasattr(lang_model.config, "d_model"): + self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model + else: + self.lang_hidden_dim = lang_model.config.hidden_size + self.vis_embedding_dim = vision_tokenizer.dim_media + self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media + + # core components + self.vision_encoder = vision_encoder + self.vision_tokenizer = vision_tokenizer + self.lang_model = lang_model + + # lm embeddings + input_embeds = DecoupledEmbedding( + num_embeddings=tokenizer_vocab_size, + num_additional_embeddings=len(self.special_tokens), + embedding_dim=self.lang_embedding_dim, + ) + input_embeds.weight = self.lang_model.get_input_embeddings().weight + self.lang_model.set_input_embeddings(input_embeds) + + lang_model_bias = self.lang_model.get_output_embeddings().bias + out_embeds = DecoupledLinear( + in_features=self.lang_embedding_dim, + out_features=tokenizer_vocab_size, + bias=lang_model_bias is not None, + out_additional_features=len(self.special_tokens), + ) + if lang_model_bias is not None: + out_embeds.bias = lang_model_bias + + out_embeds.weight = self.lang_model.get_output_embeddings().weight + self.lang_model.set_output_embeddings(out_embeds) + + # gradient checkpointing + self._use_gradient_checkpointing = gradient_checkpointing + self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing + + def forward( + self, + vision_x: Optional[torch.Tensor], + lang_x: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor]]] + ] = None, + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + **kwargs, + ): + """ + Args: + vision_x: Vision input + shape (B, T_img, F, C, H, W) with F=1 + only F = 1 is supported (single-frame videos) + if T_img > the number of media tokens in the corresponding input_ids (lang_x), + only the first number of media tokens in lang_x are used + lang_x: Language input ids, with media tokens denoting where + visual media should be inserted. + shape (B, T_txt) + attention_mask: Attention mask. Defaults to None. + labels: Labels. Defaults to None. + shape (B, T_txt) + past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None. + list of length = number of decoder layers in the LM + exact implementation depends on LM, see Hugging Face docs + past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None. + shape (B, T_txt) + past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None. + use_cache (Optional[bool], optional): Whether to use cache. Defaults to False. + If True, includes key_values, media_locations, and vision_tokens in the output. 
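To make the documented shapes concrete, here is a small editorial sketch (not part of the patch) of tensors that satisfy the `forward` contract; the commented-out call is hypothetical, since it requires a constructed model.

```python
import torch

# Illustrative shapes only; values are dummies
B, T_img, F, C, H, W = 2, 3, 1, 3, 224, 224  # F=1: only single-frame "videos" are supported
T_txt = 32

vision_x = torch.randn(B, T_img, F, C, H, W)    # pixels for up to T_img images per sequence
lang_x = torch.randint(0, 50_000, (B, T_txt))   # token ids; media positions hold the media token id
attention_mask = torch.ones_like(lang_x)

assert vision_x.ndim == 6
# model(vision_x=vision_x, lang_x=lang_x, attention_mask=attention_mask)  # hypothetical call
```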
+ """ + assert not (past_vision_tokens is None) ^ ( + past_media_locations is None + ), "past_vision_tokens and past_media_locations must both be None or both be not None" + + # convert pixels to vision tokens + if vision_x is not None: + vision_features = self._encode_vision_x(vision_x=vision_x) + vision_tokens = self.vision_tokenizer(vision_features) + else: + vision_tokens = None + + # fuse the vision and language tokens + new_inputs = self._prepare_inputs_for_forward( + vision_tokens=vision_tokens, + lang_x=lang_x, + attention_mask=attention_mask, + labels=labels, + past_key_values=past_key_values, + past_media_locations=past_media_locations, + past_vision_tokens=past_vision_tokens, + ) + output = self.lang_model( + **new_inputs, + use_cache=use_cache, + past_key_values=past_key_values, + **kwargs, + ) + + # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream + # or to add the past_vision_tokens and past_media_locations to the output + output = self._postprocess_outputs_from_forward( + output=output, + lang_x=lang_x, + vision_tokens=vision_tokens, + use_cache=use_cache, + past_vision_tokens=past_vision_tokens, + past_media_locations=past_media_locations, + ) + + # postforward hooks + self._post_forward_hook() + return output + + def _encode_vision_x(self, vision_x: torch.Tensor): + """ + Compute media tokens from vision input by passing it through vision encoder and conditioning language model. + Args: + vision_x: Vision input + shape (B, T_img, F, C, H, W) + Images in the same chunk are collated along T_img, and frames are collated along F + Currently only F=1 is supported (single-frame videos) + + rearrange code based on https://github.com/dhansmair/flamingo-mini + """ + assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" + b, T, F = vision_x.shape[:3] + + vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") + with torch.no_grad(): + vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples + vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) + return vision_x + + def generate( + self, + vision_x: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor = None, + past_key_values: Optional[ + List[Union[torch.Tensor, Tuple[torch.Tensor]]] + ] = None, + past_media_locations: Optional[torch.Tensor] = None, + past_vision_tokens: Optional[torch.Tensor] = None, + **kwargs, + ): + """ + Generate text conditioned on vision and language inputs. + Args: + vision_x (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + see documentation for forward + lang_x (torch.Tensor): Language input + shape (B, T_txt) + attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. + **kwargs: see generate documentation in Hugging Face CausalLM models. 
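An editorial sketch (not part of the patch) of the einops round trip performed in `_encode_vision_x`; the sizes, and the pretend encoder output `feats`, are dummies chosen for illustration.

```python
import torch
from einops import rearrange

# Dummy sizes mirroring the (b, T_img, F, C, H, W) input contract
b, T, F, c, h, w = 2, 3, 1, 3, 8, 8
vision_x = torch.randn(b, T, F, c, h, w)

# Collapse batch, media, and frame dims so every frame goes through the vision encoder at once
flat = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
assert flat.shape == (b * T * F, c, h, w)

# Pretend the vision encoder returned v=4 patch tokens of dim d=16 per frame
v, d = 4, 16
feats = torch.randn(b * T * F, v, d)

# Restore the (b, T, F, v, d) layout expected by the vision tokenizer
unflat = rearrange(feats, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
assert unflat.shape == (b, T, F, v, d)
```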
+ Returns: + torch.Tensor: lang_x with generated tokens appended to it + """ + num_beams = kwargs.pop("num_beams", 1) + if num_beams > 1: + vision_x = vision_x.repeat_interleave(num_beams, dim=0) + + # convert pixels to vision tokens + if vision_x is not None: + vision_features = self._encode_vision_x(vision_x=vision_x) + vision_tokens = self.vision_tokenizer(vision_features) + else: + vision_tokens = None + + # fuse the vision and language tokens + new_inputs = self._prepare_inputs_for_forward( + vision_tokens=vision_tokens, + lang_x=lang_x, + attention_mask=attention_mask, + past_key_values=past_key_values, + past_media_locations=past_media_locations, + past_vision_tokens=past_vision_tokens, + ) + output = self.lang_model.generate( + **new_inputs, + past_key_values=past_key_values, + past_media_locations=past_media_locations, + past_vision_tokens=past_vision_tokens, + num_beams=num_beams, + **kwargs, + ) + + self._post_forward_hook() + return output + + @property + def num_trainable_params(self): + """Print the number of trainable parameters""" + return sum(p.numel() for p in self.parameters() if p.requires_grad) + + def set_trainable(self): + """ + Freeze appropriate parameters in the model. + """ + raise NotImplementedError + + @property + def special_tokens(self): + """ + Returns a dict mapping from the attribute name of a special token to its string format, + e.g. "media_token": "" + """ + assert ( + "media_token" in self._special_tokens + ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id" + assert ( + "pad_token" in self._special_tokens + ), "VLMs need to request that the tokenizer call set_special_token_ids and set self.pad_token_id" + return self._special_tokens + + def set_special_token_ids(self, string_to_ids): + """ + Args: + string_to_ids (dict): mapping from token string to id + """ + assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys())) + for att_name, token_str in self.special_tokens.items(): + token_id = string_to_ids[token_str] + setattr(self, f"{att_name}_id", token_id) + setattr(self.lang_model, f"{att_name}_id", token_id) + + +class VLMWithCrossAttention(VLM): + """ + VLM using cross-attention to fuse vision and language tokens. 
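For context, a hypothetical sketch (not part of the patch) of how a tokenizer could be wired up to `set_special_token_ids`. The token strings `<image>` and `<pad>` and the use of gpt2 are assumptions made for illustration only; the actual strings live in the factory code elsewhere in this PR.

```python
from transformers import AutoTokenizer

# Assumed token strings, for illustration only
special_tokens = {"media_token": "<image>", "pad_token": "<pad>"}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens(
    {
        "additional_special_tokens": [special_tokens["media_token"]],
        "pad_token": special_tokens["pad_token"],
    }
)

# Map token strings to ids, then hand them to the model
string_to_ids = {s: tokenizer.convert_tokens_to_ids(s) for s in special_tokens.values()}
# model.set_special_token_ids(string_to_ids)  # sets model.media_token_id / model.pad_token_id
```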
+ """ + + def __init__( + self, + vision_encoder: nn.Module, + vision_tokenizer: nn.Module, + lang_model: nn.Module, + tokenizer_vocab_size: int, + gradient_checkpointing: bool = False, + decoder_layers_attr_name: str = None, + cross_attn_every_n_layers: int = None, + ): + extend_instance(lang_model, CrossAttentionMixin) + super().__init__( + vision_encoder=vision_encoder, + vision_tokenizer=vision_tokenizer, + lang_model=lang_model, + tokenizer_vocab_size=tokenizer_vocab_size, + gradient_checkpointing=gradient_checkpointing, + ) + self.lang_model.set_decoder_layers_attr_name(decoder_layers_attr_name) + self.lang_model.init_cross_attention_layers( + lang_hidden_size=self.lang_hidden_dim, + vis_hidden_size=self.vis_embedding_dim, + cross_attn_every_n_layers=cross_attn_every_n_layers, + gradient_checkpointing=gradient_checkpointing, + ) + + def _prepare_inputs_for_forward( + self, + vision_tokens: torch.Tensor, + lang_x: torch.Tensor, + attention_mask: torch.Tensor, + labels: torch.Tensor = None, + past_key_values=None, + past_media_locations: torch.Tensor = None, + past_vision_tokens: torch.Tensor = None, + ): + """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence""" + self.lang_model._condition_media_before_forward( + input_ids=lang_x, + vision_tokens=vision_tokens, + past_media_locations=past_media_locations, + past_vision_tokens=past_vision_tokens, + ) + return { + "input_ids": lang_x, + "attention_mask": attention_mask, + "labels": labels, + } + + def _postprocess_outputs_from_forward( + self, + output: CausalLMOutputWithPast, + lang_x: torch.Tensor, + vision_tokens: torch.Tensor, + past_vision_tokens: torch.Tensor, + past_media_locations: torch.Tensor, + use_cache: bool = False, + ): + """Include the past vision tokens and past media locations in the output""" + if use_cache: + if past_media_locations is not None and past_vision_tokens is not None: + if vision_tokens is not None: + updated_vision_tokens = torch.cat( + [ + past_vision_tokens, + vision_tokens, + ], + dim=1, + ) + else: + updated_vision_tokens = past_vision_tokens + updated_media_locations = torch.cat( + [ + past_media_locations, + lang_x == self.media_token_id, + ], + dim=1, + ) + else: + updated_vision_tokens = vision_tokens + updated_media_locations = lang_x == self.media_token_id + + else: + updated_vision_tokens = None + updated_media_locations = None + + output = VLMOutputWithPast( + loss=output.loss, + logits=output.logits, + past_key_values=output.past_key_values, + hidden_states=output.hidden_states, + attentions=output.attentions, + past_media_locations=updated_media_locations, + past_vision_tokens=updated_vision_tokens, + ) + + return output + + def _post_forward_hook(self): + # clear the conditioned layers + self.lang_model.clear_conditioned_layers() + + +class VLMWithLanguageStream(VLM): + """ + VLM that fuses modalities by inserting vision tokens directly into the language stream. 
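An editorial toy (not part of the patch) showing how the cached media-location mask grows across decoding steps, mirroring the `torch.cat` update in `_postprocess_outputs_from_forward`; the media token id is an assumed dummy.

```python
import torch

media_token_id = 32000                               # assumed id, for illustration
prompt = torch.tensor([[32000, 5, 6, 7]])            # one media token followed by text
past_media_locations = prompt == media_token_id      # (B, 4) boolean mask over the prompt

# A generated text token extends the mask by one position per step
next_token = torch.tensor([[8]])
updated = torch.cat([past_media_locations, next_token == media_token_id], dim=1)
# updated: tensor([[ True, False, False, False, False]])
```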
+    """
+
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        vision_tokenizer: nn.Module,
+        lang_model: nn.Module,
+        tokenizer_vocab_size: int,
+        gradient_checkpointing: bool = False,
+    ):
+        super().__init__(
+            vision_encoder=vision_encoder,
+            vision_tokenizer=vision_tokenizer,
+            lang_model=lang_model,
+            tokenizer_vocab_size=tokenizer_vocab_size,
+            gradient_checkpointing=gradient_checkpointing,
+        )
+        assert (
+            self.vis_embedding_dim == self.lang_embedding_dim
+        ), "To place visual tokens directly in the language stream, the visual and language tokens need to be the same dim."
+
+    def _prepare_inputs_for_forward(
+        self,
+        vision_tokens: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels: torch.Tensor = None,
+        past_key_values=None,
+        past_media_locations: torch.Tensor = None,
+        past_vision_tokens: torch.Tensor = None,
+    ):
+        """
+        Insert the vision tokens directly into the language stream.
+        This requires us to modify the input_ids, attention_mask, and labels.
+        """
+        # handle past_key_values
+        B, _ = lang_x.shape
+        if past_key_values is not None:
+            past_len = past_key_values[0][0].shape[2]
+            attention_mask = torch.cat(
+                [
+                    torch.ones(B, past_len, dtype=torch.long).to(
+                        attention_mask.device
+                    ),  # TODO: not sure these should all be 1
+                    attention_mask,
+                ],
+                dim=1,
+            )
+
+        if vision_tokens is None:
+            return {
+                "input_ids": lang_x,
+                "attention_mask": attention_mask,
+                "labels": labels,
+            }
+
+        # get the language embeddings
+        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
+
+        # build up the multimodal embeddings
+        has_labels = labels is not None
+        multimodal_embeds = []
+        multimodal_attention_mask = []
+        multimodal_labels = [] if has_labels else None
+        for i in range(B):
+            # get index of media tokens in lang_x[i]
+            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
+
+            # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
+            for j, img_idx in enumerate(image_token_idxs):
+                image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j
+
+            # loop through the image_token_idxs and insert the vision tokens
+            new_embed = lang_embeds[i].clone()
+            new_attention_mask = (
+                attention_mask[i].clone() if attention_mask is not None else None
+            )
+            if has_labels:
+                new_label = labels[i].clone()
+
+            for img_num, img_idx in enumerate(image_token_idxs):
+                new_embed = torch.cat(
+                    (
+                        new_embed[:img_idx],
+                        vision_tokens[i][img_num],
+                        new_embed[img_idx + 1 :],
+                    ),
+                    dim=0,
+                )
+                new_attention_mask = torch.cat(
+                    (
+                        new_attention_mask[:img_idx],
+                        torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                            attention_mask.device
+                        ),
+                        new_attention_mask[img_idx + 1 :],
+                    ),
+                    dim=0,
+                )
+                if has_labels:
+                    new_label = torch.cat(
+                        (
+                            new_label[:img_idx],
+                            torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
+                                labels.device
+                            )
+                            * -100,
+                            new_label[img_idx + 1 :],
+                        ),
+                        dim=0,
+                    )
+            multimodal_embeds.append(new_embed)
+            multimodal_attention_mask.append(new_attention_mask)
+            if has_labels:
+                multimodal_labels.append(new_label)
+
+        # stack
+        multimodal_embeds = stack_with_padding(
+            multimodal_embeds, padding_value=self.pad_token_id
+        )
+        multimodal_attention_mask = stack_with_padding(
+            multimodal_attention_mask, padding_value=0
+        )
+        if has_labels:
+            multimodal_labels = stack_with_padding(
+                multimodal_labels, padding_value=-100
+            )
+
+        return {
+            "inputs_embeds": multimodal_embeds,
+            "attention_mask": multimodal_attention_mask,
+            "labels": multimodal_labels,
+        }
+
+    def _postprocess_outputs_from_forward(
+        self,
+        output: CausalLMOutputWithPast,
+        lang_x: torch.Tensor,
+        vision_tokens: torch.Tensor,
+        past_vision_tokens: torch.Tensor,
+        past_media_locations: torch.Tensor,
+        use_cache: bool = False,
+    ):
+        # Include the past vision tokens and past media locations in the output
+        if use_cache:
+            if past_media_locations is not None and past_vision_tokens is not None:
+                if vision_tokens is not None:
+                    updated_vision_tokens = torch.cat(
+                        [
+                            past_vision_tokens,
+                            vision_tokens,
+                        ],
+                        dim=1,
+                    )
+                else:
+                    updated_vision_tokens = past_vision_tokens
+                updated_media_locations = torch.cat(
+                    [
+                        past_media_locations,
+                        lang_x == self.media_token_id,
+                    ],
+                    dim=1,
+                )
+            else:
+                updated_vision_tokens = vision_tokens
+                updated_media_locations = lang_x == self.media_token_id
+
+        else:
+            updated_vision_tokens = None
+            updated_media_locations = None
+
+        # return logits that are the same shape as the original input_ids
+        logits = output.logits
+        batch_logits = []
+        B, T_txt = lang_x.shape
+        for i in range(B):
+            sequence_logits = []
+            logits_j = 0
+            for j in range(T_txt):
+                if lang_x[i, j] != self.media_token_id:
+                    sequence_logits.append(logits[i, logits_j])
+                    logits_j += 1
+                else:
+                    # append the logit at the first vision-token position, then skip over the rest
+                    # note: this logit is the prediction made at the first inserted vision token,
+                    # since the original media token was replaced in the language stream
+                    sequence_logits.append(logits[i, logits_j])
+                    logits_j += self.num_tokens_per_vis
+            sequence_logits = torch.stack(sequence_logits, dim=0)  # (T_txt, vocab_size)
+            batch_logits.append(sequence_logits)
+
+        batch_logits = torch.stack(batch_logits, dim=0)  # (B, T_txt, vocab_size)
+        # The final logits shape should be the same as the original input_ids shape
+        assert batch_logits.shape[:2] == (B, T_txt)
+
+        # assemble the output
+        output = VLMOutputWithPast(
+            loss=output.loss,
+            logits=batch_logits,
+            past_key_values=output.past_key_values,
+            hidden_states=output.hidden_states,
+            attentions=output.attentions,
+            past_media_locations=updated_media_locations,
+            past_vision_tokens=updated_vision_tokens,
+        )
+
+        return output
+
+    def _post_forward_hook(self):
+        pass
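To close out, an editorial toy check (not part of the patch) of the length accounting that `VLMWithLanguageStream` relies on: each media token expands into `num_tokens_per_vis` embeddings on the way in, and `_postprocess_outputs_from_forward` collapses the logits back to one per original position. The ids and sizes below are dummies.

```python
import torch

media_token_id, num_tokens_per_vis = 32000, 3   # assumed values, for illustration
lang_x = torch.tensor([[1, 32000, 2, 3]])       # (B=1, T_txt=4), one media token
T_txt = lang_x.shape[1]
n_media = int((lang_x == media_token_id).sum())

# Each media token is replaced by num_tokens_per_vis embeddings in the language stream
stream_len = T_txt + (num_tokens_per_vis - 1) * n_media
assert stream_len == 6

# The postprocessing walks the (1, stream_len, vocab) logits and keeps one logit per
# original position, so the returned logits are (1, T_txt, vocab) again
vocab = 10
stream_logits = torch.randn(1, stream_len, vocab)
kept, j = [], 0
for tok in lang_x[0]:
    kept.append(stream_logits[0, j])
    j += num_tokens_per_vis if tok == media_token_id else 1
assert torch.stack(kept).shape == (T_txt, vocab)
```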