Merge branch 'deepspeed' into deepspeed_inference
anas-awadalla authored Sep 15, 2023
2 parents 6f62054 + 3805d0f commit 27ce458
Showing 8 changed files with 359 additions and 101 deletions.
34 changes: 12 additions & 22 deletions open_flamingo/scripts/run_train.sh
@@ -1,42 +1,32 @@
#!/bin/bash
#SBATCH --nodes 1
#SBATCH --ntasks-per-node=6
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-task=1
#SBATCH --account=efml
#SBATCH --partition=gpu
#SBATCH --time=48:00:00
#SBATCH --job-name=flamingo

export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=0
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=15000
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"

export PYTHONPATH="$PYTHONPATH:open_flamingo"
srun --cpu_bind=v --accel-bind=gn python



deepspeed open_flamingo/open_flamingo/train/train.py \
--lm_path meta-llama/Llama-2-13b \
--tokenizer_path meta-llama/Llama-2-13b \
--cross_attn_every_n_layers 4 \
srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
--dataset_resampled \
--batch_size_mmc4 16 \
--batch_size_laion 32 \
--deepspeed \
--batch_size_mmc4 32 \
--batch_size_laion 64 \
--train_num_samples_mmc4 125000 \
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=4 \
--run_name "deepspeed" \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--num_epochs 480 \
--warmup_steps 0 \
--mmc4_textsim_threshold 0.0 \
--laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
--mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--gradient_checkpointing \
--report_to_wandb \
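
As a quick sanity check on the launch settings, here is a back-of-the-envelope sketch (not part of the repository) of the per-step sample counts implied by the 8-task variant of the script, assuming one training process per GPU, a single node, and no gradient accumulation:

gpus = 1 * 8                           # --nodes 1, --ntasks-per-node=8, --gpus-per-task=1
mmc4_samples_per_step = gpus * 32      # --batch_size_mmc4 32 (per process)
laion_samples_per_step = gpus * 64     # --batch_size_laion 64 (per process)
print(mmc4_samples_per_step, laion_samples_per_step)  # 256 512
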
69 changes: 56 additions & 13 deletions open_flamingo/src/factory.py
@@ -2,6 +2,7 @@

from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip
import torch.nn as nn

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
@@ -14,9 +15,9 @@ def create_model_and_transforms(
lang_encoder_path: str,
tokenizer_path: str,
cross_attn_every_n_layers: int = 1,
untie_embeddings: bool = False,
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
freeze_lm_embeddings: bool = False,
cache_dir: Optional[str] = None,
**flamingo_kwargs,
):
@@ -30,9 +31,9 @@
lang_encoder_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
untie_embeddings (bool, optional): whether to untie the input and output embeddings if they are tied. This is required for training using FSDP. Defaults to False.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
@@ -50,24 +51,39 @@
text_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
local_files_only=use_local_files,
trust_remote_code=True,
cache_dir=cache_dir,
# trust_remote_code=True
)

# add Flamingo special tokens to the tokenizer
text_tokenizer.add_special_tokens(
{"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)
if text_tokenizer.pad_token is None:
new_tokens = 2
if text_tokenizer.pad_token is None and text_tokenizer.pad_token_id is None: # need to check both because some tokenizers have a pad token id but not a pad token
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

text_tokenizer.pad_token_id = text_tokenizer.eos_token_id

# text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
# new_tokens += 1

ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids(
["<|endofchunk|>","<image>","<PAD>"] if new_tokens == 3 else ["<|endofchunk|>", "<image>"]
)

lang_encoder = AutoModelForCausalLM.from_pretrained(
lang_encoder_path,
local_files_only=use_local_files,
trust_remote_code=True,
cache_dir=cache_dir,
# trust_remote_code=True
)

lang_encoder.config.update({"original_vocab_size": min(ids_for_additional_special_tokens)})
lang_encoder.config.vocab_size = max(len(text_tokenizer), lang_encoder.config.vocab_size)

# change model's vocab size to include new tokens
lang_encoder.config.vocab_size = len(text_tokenizer)

# hacks for MPT-1B, which doesn't have a get_input_embeddings method
if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -81,6 +97,28 @@ def set_input_embeddings(self, new_embeddings):

extend_instance(lang_encoder, EmbeddingFnMixin)

if not hasattr(lang_encoder, "get_output_embeddings"):
if hasattr(lang_encoder, "lm_head"):
lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
else:
raise ValueError(
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)

if not hasattr(lang_encoder, "set_output_embeddings"):
if hasattr(lang_encoder, "lm_head"):
lang_encoder.set_output_embeddings = lambda x: setattr(
lang_encoder, "lm_head", x
)
else:
raise ValueError(
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)

if untie_embeddings:
lang_encoder.get_output_embeddings().weight = nn.Parameter(lang_encoder.get_output_embeddings().weight.clone())
lang_encoder.config.update({"tie_word_embeddings": False})

# convert LM to FlamingoLM
extend_instance(lang_encoder, FlamingoLMMixin)

@@ -100,19 +138,24 @@ def set_input_embeddings(self, new_embeddings):
"width"
],
cross_attn_every_n_layers=cross_attn_every_n_layers,
new_tokens=new_tokens, # number of new token embeddings to train
padding_token_id=text_tokenizer.pad_token_id,
**flamingo_kwargs,
)

# Freeze all parameters
model.requires_grad_(False)
assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
model.vision_encoder.requires_grad_(False)
model.lang_encoder.requires_grad_(False)

# Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
# Unfreeze gated_cross_attn_layers and perceiver
model.perceiver.requires_grad_(True)
model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
if not freeze_lm_embeddings:
model.lang_encoder.get_input_embeddings().requires_grad_(True)
# TODO: investigate also training the output embeddings when untied

if hasattr(model.lang_encoder.get_output_embeddings(), "additional_fc"):
model.lang_encoder.get_output_embeddings().additional_fc.requires_grad_(True)

if hasattr(model.lang_encoder.get_input_embeddings(), "additional_embedding"):
model.lang_encoder.get_input_embeddings().additional_embedding.requires_grad_(True)

print(
f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
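
To make the special-token bookkeeping in create_model_and_transforms easier to follow, here is a minimal, self-contained sketch of the same logic (not the repository's code; it uses the GPT-2 tokenizer as a stand-in and needs network access to download it):

from transformers import AutoTokenizer

text_tokenizer = AutoTokenizer.from_pretrained("gpt2")

# add the two Flamingo special tokens
new_tokens = 2
text_tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
)

# reuse EOS as PAD instead of growing the vocabulary with a third token
if text_tokenizer.pad_token is None and text_tokenizer.pad_token_id is None:
    text_tokenizer.pad_token_id = text_tokenizer.eos_token_id

# the smallest id among the added tokens marks where the original vocabulary ends
ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids(
    ["<|endofchunk|>", "<image>"]
)
original_vocab_size = min(ids_for_additional_special_tokens)  # 50257 for GPT-2
extended_vocab_size = len(text_tokenizer)                      # 50259 for GPT-2
print(original_vocab_size, extended_vocab_size)
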
38 changes: 32 additions & 6 deletions open_flamingo/src/flamingo.py
@@ -21,7 +21,10 @@ def __init__(
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
padding_token_id: int,
vis_dim: int,
# vocab_size: int,
new_tokens: int,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
):
@@ -31,9 +34,12 @@
lang_encoder (nn.Module): HF causal language model
eoc_token_id (int): Token id for <|endofchunk|>
media_token_id (int): Token id for <image>
padding_token_id (int): Token id for padding token
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
new_tokens (int): Number of new tokens added to the tokenizer.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
"""
super().__init__()
self.eoc_token_id = eoc_token_id
@@ -49,10 +55,12 @@ def __init__(
self.lang_encoder = lang_encoder
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
padding_token_id=padding_token_id,
lang_hidden_size=self.lang_dim,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
new_tokens=new_tokens,
)
self._use_gradient_checkpointing = gradient_checkpointing
self.perceiver._use_gradient_checkpointing = gradient_checkpointing
@@ -268,12 +276,30 @@ def wrap_fsdp(self, wrapper_kwargs, device_id):
for layer in self.lang_encoder.gated_cross_attn_layers
)
self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
self.lang_encoder.set_input_embeddings(
wrap(wrap(self.lang_encoder.get_input_embeddings()))
)
self.lang_encoder.set_output_embeddings(
wrap(wrap(self.lang_encoder.get_output_embeddings()))
)
if hasattr(self.lang_encoder.get_input_embeddings(), "additional_embedding"):
# wrap additional_embedding and the original embedding separately; this is the case when using decoupled embeddings
self.lang_encoder.get_input_embeddings().additional_embedding = wrap(
wrap(self.lang_encoder.get_input_embeddings().additional_embedding)
)
self.lang_encoder.get_input_embeddings().weight = wrap(wrap(self.lang_encoder.get_input_embeddings().weight))
else:
self.lang_encoder.set_input_embeddings(
wrap(wrap(self.lang_encoder.get_input_embeddings()))
)

if hasattr(self.lang_encoder.get_output_embeddings(), "additional_fc"):
# wrap additional_fc and the original output projection separately; this is the case when using a decoupled linear layer
self.lang_encoder.get_output_embeddings().additional_fc = wrap(
wrap(self.lang_encoder.get_output_embeddings().additional_fc)
)
self.lang_encoder.get_output_embeddings().weight = wrap(wrap(self.lang_encoder.get_output_embeddings().weight))
if self.lang_encoder.get_output_embeddings().bias is not None:
self.lang_encoder.get_output_embeddings().bias = wrap(wrap(self.lang_encoder.get_output_embeddings().bias))
else:
self.lang_encoder.set_output_embeddings(
wrap(wrap(self.lang_encoder.get_output_embeddings()))
)

self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen

# manually move non-FSDP managed parameters to device_id
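
The branching added to wrap_fsdp decides whether the embeddings are wrapped as a single FSDP unit or whether the trainable decoupled pieces get their own wrappers. A simplified illustration of that decision follows (a sketch under assumptions, not the repository's code; pieces_to_wrap_separately and the toy modules are hypothetical):

import torch.nn as nn

def pieces_to_wrap_separately(input_embeddings: nn.Module):
    """Return the pieces that would be handed to FSDP's wrap() individually."""
    if hasattr(input_embeddings, "additional_embedding"):
        # decoupled case: the small trainable table and the frozen base weight are kept
        # apart so that frozen and trainable tensors are not flattened into one unit
        return [input_embeddings.additional_embedding, input_embeddings.weight]
    # standard case: the whole embedding module is wrapped as one unit
    return [input_embeddings]

plain = nn.Embedding(100, 16)
decoupled = nn.Embedding(100, 16)
decoupled.additional_embedding = nn.Embedding(2, 16)  # mimic FlamingoDecoupledEmbedding

print(len(pieces_to_wrap_separately(plain)))      # 1
print(len(pieces_to_wrap_separately(decoupled)))  # 2
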
39 changes: 38 additions & 1 deletion open_flamingo/src/flamingo_lm.py
@@ -1,5 +1,9 @@
import torch.nn as nn
from .helpers import GatedCrossAttentionBlock
from .helpers import (
GatedCrossAttentionBlock,
FlamingoDecoupledEmbedding,
FlamingoDecoupledLinear,
)
from .utils import getattr_recursive, setattr_recursive


@@ -83,10 +87,12 @@ def _set_decoder_layers(self, value):
def init_flamingo(
self,
media_token_id,
padding_token_id,
lang_hidden_size,
vis_hidden_size,
cross_attn_every_n_layers,
gradient_checkpointing,
new_tokens,
):
"""
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
@@ -104,6 +110,37 @@ def init_flamingo(
)
self.init_flamingo_layers(gradient_checkpointing)
self.media_token_id = media_token_id

# modify the embedding layer to support decoupling
input_embeds = FlamingoDecoupledEmbedding(
num_embeddings=self.config.original_vocab_size,
num_additional_embeddings=new_tokens,
embedding_dim=self.config.hidden_size,
partially_freeze=True,
padding_idx=padding_token_id,
)
input_embeds.weight = self.get_input_embeddings().weight
input_embeds.additional_embedding.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
self.set_input_embeddings(input_embeds)

out_embeds = FlamingoDecoupledLinear(
in_features=self.config.hidden_size,
out_features=self.config.original_vocab_size,
bias=self.get_output_embeddings().bias is not None,
out_additional_features=new_tokens,
partially_freeze=True,
)

if self.get_output_embeddings().bias is not None:
out_embeds.bias = self.get_output_embeddings().bias

out_embeds.weight = self.get_output_embeddings().weight
out_embeds.additional_fc.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
self.set_output_embeddings(out_embeds)

if getattr(self.config, "tie_word_embeddings", False):
self.get_output_embeddings().additional_fc.weight = self.get_input_embeddings().additional_embedding.weight

self.initialized_flamingo = True
self._use_cached_vision_x = False

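
FlamingoDecoupledEmbedding and FlamingoDecoupledLinear are imported from .helpers, which is not part of this diff. The sketch below shows the general idea under stated assumptions (it is a simplified reconstruction, not the classes from helpers.py; the real implementations likely differ in details such as the partially_freeze flag and bias handling): a frozen base table covers the original vocabulary, while a small trainable table or projection covers only the newly added tokens.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoupledEmbeddingSketch(nn.Module):
    """Frozen base embedding for the original vocab plus a trainable table for new tokens."""

    def __init__(self, num_embeddings, num_additional_embeddings, embedding_dim, padding_idx=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        # base table is frozen (the partially_freeze=True case above)
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim), requires_grad=False)
        self.additional_embedding = nn.Embedding(num_additional_embeddings, embedding_dim)
        self.padding_idx = padding_idx

    def forward(self, input_ids):
        # look everything up in the frozen base table (clamping the new ids so the index
        # stays valid), then overwrite the rows that belong to the added tokens
        base = F.embedding(input_ids.clamp(max=self.num_embeddings - 1), self.weight,
                           padding_idx=self.padding_idx)
        is_new = input_ids >= self.num_embeddings
        if is_new.any():
            base[is_new] = self.additional_embedding(input_ids[is_new] - self.num_embeddings)
        return base

class DecoupledLinearSketch(nn.Module):
    """Frozen base output projection plus a trainable projection for the new tokens."""

    def __init__(self, in_features, out_features, out_additional_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        self.additional_fc = nn.Linear(in_features, out_additional_features, bias=False)

    def forward(self, hidden):
        base_logits = F.linear(hidden, self.weight)
        new_logits = self.additional_fc(hidden)
        # logits for the original vocab followed by logits for the added tokens
        return torch.cat([base_logits, new_logits], dim=-1)

# quick shape check: original vocab of 10, 2 added tokens, hidden size 8
emb = DecoupledEmbeddingSketch(10, 2, 8)
head = DecoupledLinearSketch(8, 10, 2)
ids = torch.tensor([[1, 9, 10, 11]])  # ids 10 and 11 stand for the added special tokens
print(head(emb(ids)).shape)           # torch.Size([1, 4, 12])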