diff --git a/open_flamingo/scripts/run_train.sh b/open_flamingo/scripts/run_train.sh
index dce882e1..8d45355e 100644
--- a/open_flamingo/scripts/run_train.sh
+++ b/open_flamingo/scripts/run_train.sh
@@ -1,11 +1,7 @@
 #!/bin/bash
 #SBATCH --nodes 1
-#SBATCH --ntasks-per-node=6
+#SBATCH --ntasks-per-node=8
 #SBATCH --gpus-per-task=1
-#SBATCH --account=efml
-#SBATCH --partition=gpu
-#SBATCH --time=48:00:00
-#SBATCH --job-name=flamingo
 
 export PYTHONFAULTHANDLER=1
 export CUDA_LAUNCH_BLOCKING=0
@@ -13,30 +9,24 @@ export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
 export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
 export MASTER_PORT=15000
 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
-export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"
 
 export PYTHONPATH="$PYTHONPATH:open_flamingo"
-srun --cpu_bind=v --accel-bind=gn python
-
-
-
-deepspeed open_flamingo/open_flamingo/train/train.py \
-  --lm_path meta-llama/Llama-2-13b \
-  --tokenizer_path meta-llama/Llama-2-13b \
-  --cross_attn_every_n_layers 4 \
+srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
+  --lm_path anas-awadalla/mpt-1b-redpajama-200b \
+  --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
+  --cross_attn_every_n_layers 1 \
   --dataset_resampled \
-  --batch_size_mmc4 16 \
-  --batch_size_laion 32 \
-  --deepspeed \
+  --batch_size_mmc4 32 \
+  --batch_size_laion 64 \
   --train_num_samples_mmc4 125000\
   --train_num_samples_laion 250000 \
   --loss_multiplier_laion 0.2 \
   --workers=4 \
-  --run_name "deepspeed" \
+  --run_name OpenFlamingo-3B-vitl-mpt1b \
   --num_epochs 480 \
-  --warmup_steps 0 \
-  --mmc4_textsim_threshold 0.0 \
-  --laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
-  --mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
+  --warmup_steps 1875 \
+  --mmc4_textsim_threshold 0.24 \
+  --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
+  --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
   --gradient_checkpointing \
   --report_to_wandb \
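A minimal sketch, not part of the patch, of how the SLURM variables exported above are typically consumed on each rank. In this repo the equivalent wiring is done by the world_info_from_env and init_distributed_device calls in train.py; the helper name below is hypothetical.

import os
import torch.distributed as dist

def init_distributed_from_slurm():
    # srun starts one task per GPU; these SLURM variables identify each task
    rank = int(os.environ.get("SLURM_PROCID", 0))
    local_rank = int(os.environ.get("SLURM_LOCALID", 0))
    world_size = int(os.environ.get("SLURM_NTASKS", 1))
    # MASTER_ADDR / MASTER_PORT come from the exports in the script above
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}",
        rank=rank,
        world_size=world_size,
    )
    return rank, local_rank, world_size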
diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py
index 0fd4ec42..1f5a08fe 100644
--- a/open_flamingo/src/factory.py
+++ b/open_flamingo/src/factory.py
@@ -2,6 +2,7 @@
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import open_clip
+import torch.nn as nn
 
 from .flamingo import Flamingo
 from .flamingo_lm import FlamingoLMMixin
@@ -14,9 +15,9 @@ def create_model_and_transforms(
     lang_encoder_path: str,
     tokenizer_path: str,
     cross_attn_every_n_layers: int = 1,
+    untie_embeddings: bool = False,
     use_local_files: bool = False,
     decoder_layers_attr_name: str = None,
-    freeze_lm_embeddings: bool = False,
     cache_dir: Optional[str] = None,
     **flamingo_kwargs,
 ):
@@ -30,9 +31,9 @@ def create_model_and_transforms(
         lang_encoder_path (str): path to pretrained language encoder
         tokenizer_path (str): path to pretrained tokenizer
         cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
+        untie_embeddings (bool, optional): whether to untie the input and output embeddings if they are tied. This is required for training using FSDP. Defaults to False.
         use_local_files (bool, optional): whether to use local files. Defaults to False.
         decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
-        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
         cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
     Returns:
         Flamingo: Flamingo model from pretrained vision and language encoders
@@ -50,24 +51,39 @@ def create_model_and_transforms(
     text_tokenizer = AutoTokenizer.from_pretrained(
         tokenizer_path,
         local_files_only=use_local_files,
-        trust_remote_code=True,
         cache_dir=cache_dir,
+        # trust_remote_code=True
     )
+
     # add Flamingo special tokens to the tokenizer
     text_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
     )
-    if text_tokenizer.pad_token is None:
+    new_tokens = 2
+    if text_tokenizer.pad_token is None and text_tokenizer.pad_token_id is None:  # need to check both because some tokenizers have a pad token id but not a pad token
         # Issue: GPT models don't have a pad token, which we use to
         # modify labels for the loss.
-        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
-
+        text_tokenizer.pad_token_id = text_tokenizer.eos_token_id
+
+        # text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+        # new_tokens += 1
+
+    ids_for_additional_special_tokens = text_tokenizer.convert_tokens_to_ids(
+        ["<|endofchunk|>", "<image>", "<PAD>"] if new_tokens == 3 else ["<|endofchunk|>", "<image>"]
+    )
+
     lang_encoder = AutoModelForCausalLM.from_pretrained(
         lang_encoder_path,
         local_files_only=use_local_files,
-        trust_remote_code=True,
         cache_dir=cache_dir,
+        # trust_remote_code=True
     )
+
+    lang_encoder.config.update({"original_vocab_size": min(ids_for_additional_special_tokens)})
+    lang_encoder.config.vocab_size = max(len(text_tokenizer), lang_encoder.config.vocab_size)
+
+    # change model's vocab size to include new tokens
+    lang_encoder.config.vocab_size = len(text_tokenizer)
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
     if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -81,6 +97,28 @@ def set_input_embeddings(self, new_embeddings):
 
         extend_instance(lang_encoder, EmbeddingFnMixin)
 
+    if not hasattr(lang_encoder, "get_output_embeddings"):
+        if hasattr(lang_encoder, "lm_head"):
+            lang_encoder.get_output_embeddings = lambda: lang_encoder.lm_head
+        else:
+            raise ValueError(
+                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
+            )
+
+    if not hasattr(lang_encoder, "set_output_embeddings"):
+        if hasattr(lang_encoder, "lm_head"):
+            lang_encoder.set_output_embeddings = lambda x: setattr(
+                lang_encoder, "lm_head", x
+            )
+        else:
+            raise ValueError(
+                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
+            )
+
+    if untie_embeddings:
+        lang_encoder.get_output_embeddings().weight = nn.Parameter(lang_encoder.get_output_embeddings().weight.clone())
+        lang_encoder.config.update({"tie_word_embeddings": False})
+
     # convert LM to FlamingoLM
     extend_instance(lang_encoder, FlamingoLMMixin)
 
@@ -100,19 +138,24 @@ def set_input_embeddings(self, new_embeddings):
             "width"
         ],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
+        new_tokens=new_tokens,  # number of new token embeddings to train
+        padding_token_id=text_tokenizer.pad_token_id,
         **flamingo_kwargs,
     )
 
     # Freeze all parameters
-    model.requires_grad_(False)
-    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0
+    model.vision_encoder.requires_grad_(False)
+    model.lang_encoder.requires_grad_(False)
 
-    # Unfreeze perceiver, gated_cross_attn_layers, and LM input embeddings
+    # Unfreeze gated_cross_attn_layers and perceiver
     model.perceiver.requires_grad_(True)
     model.lang_encoder.gated_cross_attn_layers.requires_grad_(True)
-    if not freeze_lm_embeddings:
-        model.lang_encoder.get_input_embeddings().requires_grad_(True)
-        # TODO: investigate also training the output embeddings when untied
+
+    if hasattr(model.lang_encoder.get_output_embeddings(), "additional_fc"):
+        model.lang_encoder.get_output_embeddings().additional_fc.requires_grad_(True)
+
+    if hasattr(model.lang_encoder.get_input_embeddings(), "additional_embedding"):
+        model.lang_encoder.get_input_embeddings().additional_embedding.requires_grad_(True)
 
     print(
         f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
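An illustrative sketch, not part of the patch, of what the untie_embeddings branch above does: after cloning, the output head no longer shares storage with the input embedding, which the new docstring notes is required for FSDP training. The module sizes are made up.

import torch.nn as nn

tok_emb = nn.Embedding(100, 8)           # stand-in for the LM input embedding
lm_head = nn.Linear(8, 100, bias=False)  # stand-in for the LM output head
lm_head.weight = tok_emb.weight          # tied weights, as in many causal LMs

# mirror the untie_embeddings branch: clone into a fresh Parameter so the two stop aliasing
lm_head.weight = nn.Parameter(lm_head.weight.clone())
assert lm_head.weight.data_ptr() != tok_emb.weight.data_ptr()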
diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py
index 9a67cfed..61f0b6da 100644
--- a/open_flamingo/src/flamingo.py
+++ b/open_flamingo/src/flamingo.py
@@ -21,7 +21,10 @@ def __init__(
         lang_encoder: nn.Module,
         eoc_token_id: int,
         media_token_id: int,
+        padding_token_id: int,
         vis_dim: int,
+        # vocab_size: int,
+        new_tokens: int,
         cross_attn_every_n_layers: int = 1,
         gradient_checkpointing: bool = False,
     ):
@@ -31,9 +34,12 @@ def __init__(
             lang_encoder (nn.Module): HF causal language model
             eoc_token_id (int): Token id for <|endofchunk|>
             media_token_id (int): Token id for <image>
+            padding_token_id (int): Token id for padding token
             vis_dim (int): Dimension of the visual features.
                 Visual features are projected to match this shape along the last dimension.
+            new_tokens (int): Number of new tokens added to the tokenizer.
             cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
+            gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
         """
         super().__init__()
         self.eoc_token_id = eoc_token_id
@@ -49,10 +55,12 @@ def __init__(
         self.lang_encoder = lang_encoder
         self.lang_encoder.init_flamingo(
             media_token_id=media_token_id,
+            padding_token_id=padding_token_id,
             lang_hidden_size=self.lang_dim,
             vis_hidden_size=self.vis_dim,
             cross_attn_every_n_layers=cross_attn_every_n_layers,
             gradient_checkpointing=gradient_checkpointing,
+            new_tokens=new_tokens,
         )
         self._use_gradient_checkpointing = gradient_checkpointing
         self.perceiver._use_gradient_checkpointing = gradient_checkpointing
@@ -268,12 +276,30 @@ def wrap_fsdp(self, wrapper_kwargs, device_id):
                 for layer in self.lang_encoder.gated_cross_attn_layers
             )
             self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing)
-            self.lang_encoder.set_input_embeddings(
-                wrap(wrap(self.lang_encoder.get_input_embeddings()))
-            )
-            self.lang_encoder.set_output_embeddings(
-                wrap(wrap(self.lang_encoder.get_output_embeddings()))
-            )
+            if hasattr(self.lang_encoder.get_input_embeddings(), "additional_embedding"):
+                # wrap additional_embedding and the original embedding separately; this is the case when using decoupled embeddings
+                self.lang_encoder.get_input_embeddings().additional_embedding = wrap(
+                    wrap(self.lang_encoder.get_input_embeddings().additional_embedding)
+                )
+                self.lang_encoder.get_input_embeddings().weight = wrap(wrap(self.lang_encoder.get_input_embeddings().weight))
+            else:
+                self.lang_encoder.set_input_embeddings(
+                    wrap(wrap(self.lang_encoder.get_input_embeddings()))
+                )
+
+            if hasattr(self.lang_encoder.get_output_embeddings(), "additional_fc"):
+                # wrap additional_fc and the original output weight separately; this is the case when using the decoupled linear layer
+                self.lang_encoder.get_output_embeddings().additional_fc = wrap(
+                    wrap(self.lang_encoder.get_output_embeddings().additional_fc)
+                )
+                self.lang_encoder.get_output_embeddings().weight = wrap(wrap(self.lang_encoder.get_output_embeddings().weight))
+                if self.lang_encoder.get_output_embeddings().bias is not None:
+                    self.lang_encoder.get_output_embeddings().bias = wrap(wrap(self.lang_encoder.get_output_embeddings().bias))
+            else:
+                self.lang_encoder.set_output_embeddings(
+                    wrap(wrap(self.lang_encoder.get_output_embeddings()))
+                )
+
             self.vision_encoder = wrap(wrap(self.vision_encoder))  # frozen
 
         # manually move non-FSDP managed parameters to device_id
""" super().__init__() self.eoc_token_id = eoc_token_id @@ -49,10 +55,12 @@ def __init__( self.lang_encoder = lang_encoder self.lang_encoder.init_flamingo( media_token_id=media_token_id, + padding_token_id=padding_token_id, lang_hidden_size=self.lang_dim, vis_hidden_size=self.vis_dim, cross_attn_every_n_layers=cross_attn_every_n_layers, gradient_checkpointing=gradient_checkpointing, + new_tokens=new_tokens, ) self._use_gradient_checkpointing = gradient_checkpointing self.perceiver._use_gradient_checkpointing = gradient_checkpointing @@ -268,12 +276,30 @@ def wrap_fsdp(self, wrapper_kwargs, device_id): for layer in self.lang_encoder.gated_cross_attn_layers ) self.lang_encoder.init_flamingo_layers(self._use_gradient_checkpointing) - self.lang_encoder.set_input_embeddings( - wrap(wrap(self.lang_encoder.get_input_embeddings())) - ) - self.lang_encoder.set_output_embeddings( - wrap(wrap(self.lang_encoder.get_output_embeddings())) - ) + if hasattr(self.lang_encoder.get_input_embeddings(), "additional_embedding"): + # wrap additional_embedding and original embedding separately, this is the case when using decoupled embeddings + self.lang_encoder.get_input_embeddings().additional_embedding = wrap( + wrap(self.lang_encoder.get_input_embeddings().additional_embedding) + ) + self.lang_encoder.get_input_embeddings().weight = wrap(wrap(self.lang_encoder.get_input_embeddings().weight)) + else: + self.lang_encoder.set_input_embeddings( + wrap(wrap(self.lang_encoder.get_input_embeddings())) + ) + + if hasattr(self.lang_encoder.get_output_embeddings(), "additional_fc"): + # wrap additional_fc and original embedding separately, this is the case when using decoupled linear layer + self.lang_encoder.get_output_embeddings().additional_fc = wrap( + wrap(self.lang_encoder.get_output_embeddings().additional_fc) + ) + self.lang_encoder.get_output_embeddings().weight = wrap(wrap(self.lang_encoder.get_output_embeddings().weight)) + if self.lang_encoder.get_output_embeddings().bias is not None: + self.lang_encoder.get_output_embeddings().bias = wrap(wrap(self.lang_encoder.get_output_embeddings().bias)) + else: + self.lang_encoder.set_output_embeddings( + wrap(wrap(self.lang_encoder.get_output_embeddings())) + ) + self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen # manually move non-FSDP managed parameters to device_id diff --git a/open_flamingo/src/flamingo_lm.py b/open_flamingo/src/flamingo_lm.py index c4933e9d..6c9545b3 100644 --- a/open_flamingo/src/flamingo_lm.py +++ b/open_flamingo/src/flamingo_lm.py @@ -1,5 +1,9 @@ import torch.nn as nn -from .helpers import GatedCrossAttentionBlock +from .helpers import ( + GatedCrossAttentionBlock, + FlamingoDecoupledEmbedding, + FlamingoDecoupledLinear, +) from .utils import getattr_recursive, setattr_recursive @@ -83,10 +87,12 @@ def _set_decoder_layers(self, value): def init_flamingo( self, media_token_id, + padding_token_id, lang_hidden_size, vis_hidden_size, cross_attn_every_n_layers, gradient_checkpointing, + new_tokens, ): """ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py
index 239503f8..dd44737b 100644
--- a/open_flamingo/src/helpers.py
+++ b/open_flamingo/src/helpers.py
@@ -6,6 +6,7 @@
 from einops import rearrange, repeat
 from einops_exts import rearrange_many
 from torch import einsum, nn
+from torch.nn import functional as F
 
 
 def exists(val):
@@ -277,3 +278,190 @@ def forward(
         x = self.ff(x) * self.ff_gate.tanh() + x
 
         return x
+
+
+# Both FlamingoDecoupledEmbedding and FlamingoDecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity
+class FlamingoDecoupledEmbedding(nn.Embedding):
+    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
+    """
+    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
+    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
+    then it will create `num_additional_embeddings` additional parameters that are always trained. If
+    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        num_additional_embeddings,
+        embedding_dim,
+        partially_freeze=True,
+        device=None,
+        dtype=None,
+        padding_idx=None,
+        **kwargs,
+    ) -> None:
+        """
+        Args:
+            num_embeddings (`int`):
+                Size of the dictionary of embeddings
+            num_additional_embeddings (`int`):
+                Number of additional embeddings. Only useful when you `partially_freeze=True`.
+            embedding_dim (`int`):
+                The size of each embedding vector
+            partially_freeze: (`bool`, *optional*, defaults to `True`):
+                If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
+            padding_idx (`int`, *optional*):
+                The padding index (needs to be less than num_embeddings)
+
+        Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
+        `max_norm` or `norm_type`. We are not supporting these.
+        """
+        if padding_idx is not None and padding_idx > num_embeddings:
+            raise ValueError(
+                f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}"
+            )
+        super().__init__(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            device=device,
+            dtype=dtype,
+            padding_idx=padding_idx,
+            **kwargs,
+        )
+        self.num_embeddings = num_embeddings
+        self.padding_idx = padding_idx
+        self.num_additional_embeddings = num_additional_embeddings
+        self.partially_freeze = partially_freeze
+
+        if partially_freeze:
+            self.weight.requires_grad_(False)
+
+        if self.num_additional_embeddings > 0:
+            self.additional_embedding = nn.Embedding(
+                num_embeddings=self.num_additional_embeddings,
+                embedding_dim=embedding_dim,
+                device=device,
+                dtype=dtype,
+            )
+
+    def forward(self, input_ids):
+        """
+        we have 2 embeddings, with different indices - one pretrained self.weight and another
+        self.additional_embedding.weight that is being trained.
+
+        in order to make a lookup of the input ids, we:
+        1. find out the indices of the entries belonging to the 2nd embedding
+        2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
+           embedding starts from 0 and not num_embeddings
+        3. perform the 2nd embedding lookup
+        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
+        5. perform the 1st embedding lookup
+        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
+
+        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
+        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
+        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
+        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
+        measure.
+
+        """
+        if self.num_additional_embeddings == 0:
+            return F.embedding(input_ids, self.weight)
+
+        # Clone so that we don't modify the original input_ids later on
+        input_ids = input_ids.clone()
+        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
+        input_ids_additional_vocab = input_ids[additional_vocab_indices]
+        additional_embeddings = self.additional_embedding(
+            input_ids_additional_vocab - self.num_embeddings
+        )
+
+        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
+        input_ids[additional_vocab_indices] = 0
+        full_vector = F.embedding(input_ids, self.weight)
+
+        # overwrite the records with high indices
+        full_vector[additional_vocab_indices] = additional_embeddings
+
+        return full_vector
+
+    def extra_repr(self) -> str:
+        return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
+            self.num_embeddings,
+            self.num_additional_embeddings,
+            self.embedding_dim,
+            self.partially_freeze,
+        )
+ """ + if padding_idx is not None and padding_idx > num_embeddings: + raise ValueError( + f"padding_idx must be within num_embeddings. Got {padding_idx} and {num_embeddings}" + ) + super().__init__( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + padding_idx=padding_idx, + **kwargs, + ) + self.num_embeddings = num_embeddings + self.padding_idx = padding_idx + self.num_additional_embeddings = num_additional_embeddings + self.partially_freeze = partially_freeze + + if partially_freeze: + self.weight.requires_grad_(False) + + if self.num_additional_embeddings > 0: + self.additional_embedding = nn.Embedding( + num_embeddings=self.num_additional_embeddings, + embedding_dim=embedding_dim, + device=device, + dtype=dtype, + ) + + def forward(self, input_ids): + """ + we have 2 embeddings, with different indices - one pretrained self.weight and another + self.additional_embedding.weight that is being trained. + + in order to make a lookup of the input ids, we: + 1. find out the indices of the entries belonging to the 2nd embedding + 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd + embedding starts from 0 and not num_embeddings + 3. perform the 2nd embedding lookup + 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index + 5. perform the 1st embedding lookup + 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup + + note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but + then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices - + i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are + usually relatively short it's probably not faster or if faster not by much - but might be a good idea to + measure. + + """ + if self.num_additional_embeddings == 0: + return F.embedding(input_ids, self.weight) + + # Clone so that we don't modify the original input_ids later on + input_ids = input_ids.clone() + additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) + input_ids_additional_vocab = input_ids[additional_vocab_indices] + additional_embeddings = self.additional_embedding( + input_ids_additional_vocab - self.num_embeddings + ) + + # for successful lookup replace input_ids with 0, the results of these will be discarded anyway + input_ids[additional_vocab_indices] = 0 + full_vector = F.embedding(input_ids, self.weight) + + # overwrite the records with high indices + full_vector[additional_vocab_indices] = additional_embeddings + + return full_vector + + def extra_repr(self) -> str: + return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format( + self.num_embeddings, + self.num_additional_embeddings, + self.embedding_dim, + self.partially_freeze, + ) + + +class FlamingoDecoupledLinear(nn.Linear): + # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + """ + Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the + regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0, + then it will create `out_additional_features * in_features` additional parameters that are always trained. 
diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py
index be6a6b15..ef9d243b 100644
--- a/open_flamingo/train/train.py
+++ b/open_flamingo/train/train.py
@@ -122,11 +122,6 @@ def main():
         help="we define an 'epoch' as a fixed number of examples (train_num_samples_mmc4, train_num_samples_laion), not a pass through the entire dataset",
     )
     parser.add_argument("--offline", action="store_true")
-    parser.add_argument(
-        "--freeze_lm_embeddings",
-        action="store_true",
-        help="if True, we freeze the LM embeddings during training. Otherwise, we train the <image> and <|endofchunk|> embeddings.",
-    )
     parser.add_argument(
         "--logging_steps", type=int, default=100, help="log loss every n steps"
     )
@@ -237,8 +232,7 @@ def main():
     )
 
     args = parser.parse_args()
-
-    args.local_rank = int(os.environ.get('LOCAL_RANK', -1))  # for deepspeed
+    args.local_rank, args.rank, args.world_size = world_info_from_env()
 
     # Validate args
     if args.laion_shards.startswith("s3"):
@@ -253,8 +247,7 @@ def main():
     if args.fsdp and not args.fsdp_use_orig_params:
         print(
             "Warning: FSDP is running without fsdp_use_orig_params flag. "
-            + "This is not recommended because it means we will use uniform weight decay"
-            + " and train all embeddings, not just the newly added ones. "
+            + "This is not recommended because it means we will use uniform weight decay."
             + "Note: OPT models are not compatible with fsdp_use_orig_params flag."
         )
 
@@ -275,7 +268,6 @@ def main():
         os.environ["WANDB_MODE"] = "offline"
         os.environ["TRANSFORMERS_OFFLINE"] = "1"
 
-    args.local_rank, args.rank, args.world_size = world_info_from_env()
     if args.deepspeed:
         torch.cuda.set_device(args.local_rank)
         device = torch.device("cuda", args.local_rank)
@@ -322,6 +314,9 @@ def main():
 
     device_id = init_distributed_device(args)
     random_seed(args.seed)
+
+    if args.fsdp:
+        print("Untying embeddings for FSDP")
 
     # Initialize model
     model, image_processor, tokenizer = create_model_and_transforms(
@@ -330,9 +325,9 @@ def main():
         args.lm_path,
         args.tokenizer_path if args.tokenizer_path else args.lm_path,
         cross_attn_every_n_layers=args.cross_attn_every_n_layers,
+        untie_embeddings=args.fsdp,  # untie embeddings for FSDP
         use_local_files=args.offline,
         gradient_checkpointing=args.gradient_checkpointing,
-        freeze_lm_embeddings=args.freeze_lm_embeddings,
     )
 
     random_seed(args.seed, args.rank)
@@ -440,9 +435,31 @@ def main():
     elif not args.deepspeed:
         model = model.to(device_id)
         ddp_model = DDP(model, device_ids=[device_id])
+
+    # Initialize gradient checkpointing
+    if args.gradient_checkpointing:
+        if args.deepspeed:
+            raise ValueError(
+                "gradient checkpointing currently not supported with deepspeed"
+            )
+        non_reentrant_wrapper = functools.partial(
+            checkpoint_wrapper,
+            offload_to_cpu=True,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+        )
+        apply_activation_checkpointing(
+            ddp_model,
+            checkpoint_wrapper_fn=non_reentrant_wrapper,
+            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
+            and not isinstance(m, FSDP)
+            and not isinstance(m, CheckpointWrapper),
+        )
 
     # Initialize optimizer
-    params_to_optimize = model.named_parameters()
+    params_to_optimize = (
+        ddp_model.named_parameters() if not args.deepspeed else model.named_parameters()
+    )
+
     params_to_optimize = list(
         filter(
             lambda x: x[1].requires_grad
@@ -526,6 +543,9 @@ def get_grouped_params(model):
         print(
             f"After deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
         )
+        print(
+            f"After deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
+        )
 
     if args.resume_from_checkpoint is not None:
         if args.rank == 0:
@@ -541,25 +561,6 @@ def get_grouped_params(model):
             checkpoint_epoch = int(f.read().split("_")[-1])
         resume_from_epoch = checkpoint_epoch + 1
 
-    # Initialize gradient checkpointing
-    if args.gradient_checkpointing:
-        if args.deepspeed:
-            raise ValueError(
-                "gradient checkpointing currently not supported with deepspeed"
-            )
-        non_reentrant_wrapper = functools.partial(
-            checkpoint_wrapper,
-            offload_to_cpu=True,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-        )
-        apply_activation_checkpointing(
-            ddp_model,
-            checkpoint_wrapper_fn=non_reentrant_wrapper,
-            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
-            and not isinstance(m, FSDP)
-            and not isinstance(m, CheckpointWrapper),
-        )
-
     for epoch in range(resume_from_epoch, args.num_epochs):
         laion_dataset.set_epoch(epoch)
         laion_loader = laion_dataset.dataloader
+ "Note: OPT models are not compatible with fsdp_use_orig_params flag." ) @@ -275,7 +268,6 @@ def main(): os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" - args.local_rank, args.rank, args.world_size = world_info_from_env() if args.deepspeed: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) @@ -322,6 +314,9 @@ def main(): device_id = init_distributed_device(args) random_seed(args.seed) + + if args.fsdp: + print("Untying embeddings for FSDP") # Initialize model model, image_processor, tokenizer = create_model_and_transforms( @@ -330,9 +325,9 @@ def main(): args.lm_path, args.tokenizer_path if args.tokenizer_path else args.lm_path, cross_attn_every_n_layers=args.cross_attn_every_n_layers, + untie_embeddings=args.fsdp, # untie embeddings for FSDP use_local_files=args.offline, gradient_checkpointing=args.gradient_checkpointing, - freeze_lm_embeddings=args.freeze_lm_embeddings, ) random_seed(args.seed, args.rank) @@ -440,9 +435,31 @@ def main(): elif not args.deepspeed: model = model.to(device_id) ddp_model = DDP(model, device_ids=[device_id]) + + # Initialize gradient checkpointing + if args.gradient_checkpointing: + if args.deepspeed: + raise ValueError( + "gradient checkpointing currently not supported with deepspeed" + ) + non_reentrant_wrapper = functools.partial( + checkpoint_wrapper, + offload_to_cpu=True, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + apply_activation_checkpointing( + ddp_model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False) + and not isinstance(m, FSDP) + and not isinstance(m, CheckpointWrapper), + ) # Initialize optimizer - params_to_optimize = model.named_parameters() + params_to_optimize = ( + ddp_model.named_parameters() if not args.deepspeed else model.named_parameters() + ) + params_to_optimize = list( filter( lambda x: x[1].requires_grad @@ -526,6 +543,9 @@ def get_grouped_params(model): print( f"After deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}" ) + print( + f"After deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}" + ) if args.resume_from_checkpoint is not None: if args.rank == 0: @@ -541,25 +561,6 @@ def get_grouped_params(model): checkpoint_epoch = int(f.read().split("_")[-1]) resume_from_epoch = checkpoint_epoch + 1 - # Initialize gradient checkpointing - if args.gradient_checkpointing: - if args.deepspeed: - raise ValueError( - "gradient checkpointing currently not supported with deepspeed" - ) - non_reentrant_wrapper = functools.partial( - checkpoint_wrapper, - offload_to_cpu=True, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - apply_activation_checkpointing( - ddp_model, - checkpoint_wrapper_fn=non_reentrant_wrapper, - check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False) - and not isinstance(m, FSDP) - and not isinstance(m, CheckpointWrapper), - ) - for epoch in range(resume_from_epoch, args.num_epochs): laion_dataset.set_epoch(epoch) laion_loader = laion_dataset.dataloader diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index cea9ffe4..2be50a1a 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -72,9 +72,6 @@ def train_one_epoch( # setup model media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] - endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)[ - "input_ids" - ][-1] model.train() # setup logging 
diff --git a/requirements-training.txt b/requirements-training.txt
index 79ff0bc9..8b46a831 100644
--- a/requirements-training.txt
+++ b/requirements-training.txt
@@ -3,3 +3,4 @@ braceexpand
 webdataset
 tqdm
 wandb
+deepspeed
\ No newline at end of file