diff --git a/open_flamingo/src/blip.py b/open_flamingo/src/blip.py index 6b13c28e..722c7335 100644 --- a/open_flamingo/src/blip.py +++ b/open_flamingo/src/blip.py @@ -2,6 +2,7 @@ from .helpers import QFormerWithProjection from .vlm import VLMWithLanguageStream + class BLIP(VLMWithLanguageStream): def __init__( self, @@ -58,6 +59,3 @@ def set_trainable(self): def _should_apply_weight_decay(self, parameter_name): """BLIP applies 0.05 weight decay to everything""" return True - - def wrap_fsdp(self, wrapper_kwargs, device_id): - raise NotImplementedError \ No newline at end of file diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 3b980434..8b73a82d 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -23,6 +23,7 @@ def create_model_and_transforms( cache_dir: Optional[str] = None, gradient_checkpointing: bool = False, untie_embeddings: bool = False, + verbose: bool = True, **model_kwargs, ): """ @@ -40,6 +41,7 @@ def create_model_and_transforms( cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights. gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False. untie_embeddings (bool, optional): whether to untie the input and output embeddings of the language model. Defaults to False. + verbose (bool, optional): whether to print model info. Defaults to True. Returns: Flamingo: Flamingo model from pretrained vision and language encoders Image processor: Pipeline to preprocess input images @@ -79,11 +81,12 @@ def create_model_and_transforms( ) check_embedding_fns(lang_model) if untie_embeddings: + print("Untying language model embeddings...") lang_model.get_output_embeddings().weight = nn.Parameter( lang_model.get_output_embeddings().weight.clone() ) lang_model.config.update({"tie_word_embeddings": False}) - + # vocab sizes: note that lang_model.config.vocab_size is not necessarily = len(text_tokenizer) # the current input_embedding / output_embedding weights probably use lang_model.config.vocab_size # but the tokenizer will assign additional ids based on len(text_tokenizer) @@ -139,13 +142,16 @@ def create_model_and_transforms( } ) - # freeze appropraite parameters + # freeze appropriate parameters model.set_trainable() - print( - f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters" - ) - print(f"==========\n{model.num_trainable_params_per_module}") - print(f"==========\n{model.num_params_per_module}\n==========") + + # log model info + if verbose: + print( + f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters" + ) + print(f"==========\n{model.num_trainable_params_per_module}") + print(f"==========\n{model.num_params_per_module}\n==========") return model, image_processor, text_tokenizer diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py index 938dcb8d..15829361 100644 --- a/open_flamingo/src/flamingo.py +++ b/open_flamingo/src/flamingo.py @@ -1,5 +1,5 @@ from torch import nn -from .helpers import PerceiverResampler +from .helpers import PerceiverResampler, GatedCrossAttentionBlock from .vlm import VLMWithCrossAttention @@ -61,113 +61,3 @@ def _should_apply_weight_decay(self, parameter_name): Flamingo applies 0.1 weight decay to cross attention parameters """ return "gated_cross_attn" in parameter_name - - def wrap_fsdp(self, wrapper_kwargs, device_id): - """ - Manually wraps submodules for FSDP and move other parameters to device_id. - - Why manually wrap? 
- - all parameters within the FSDP wrapper must have the same requires_grad. - We have a mix of frozen and unfrozen parameters. - - model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors - See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344 - - The rough wrapping structure is: - - FlamingoModel - - FSDP(FSDP(vision_encoder)) - - FSDP(FSDP(perceiver)) - - lang_model - - FSDP(FSDP(input_embeddings)) - - CrossAttentionLayers - - FSDP(FSDP(gated_cross_attn_layer)) - - FSDP(FSDP(decoder_layer)) - - FSDP(FSDP(output_embeddings)) - - other parameters - - Known issues: - - Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied, - train with DDP or set the --freeze_lm_embeddings flag to true. - - With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound. - Although the training curves look okay, we found that downstream performance dramatically - degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M). - - FAQs about our FSDP wrapping strategy: - Why double wrap? - As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook - only free gathered parameters if the module is NOT FSDP root. - """ - print( - "WARNING: FSDP is not designed for training with a mix of frozen and unfrozen parameters. " - + "This experimental workaround results in a significant drop in GPU power usage." - ) - - from torch.distributed.fsdp.wrap import ( - enable_wrap, - wrap, - ) - from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ) - from .utils import apply_with_stopping_condition - - # wrap in FSDP - with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs): - self.perceiver = wrap(wrap(self.perceiver)) - self.lang_model.old_decoder_blocks = nn.ModuleList( - wrap(wrap(block)) for block in self.lang_model.old_decoder_blocks - ) - self.lang_model.gated_cross_attn_layers = nn.ModuleList( - wrap(wrap(layer)) if layer is not None else None - for layer in self.lang_model.gated_cross_attn_layers - ) - self.lang_model.init_flamingo_layers(self._use_gradient_checkpointing) - if hasattr(self.lang_model.get_input_embeddings(), "additional_embedding"): - # wrap additional_embedding and original embedding separately, this is the case when using decoupled embeddings - self.lang_model.get_input_embeddings().additional_embedding = wrap( - wrap(self.lang_model.get_input_embeddings().additional_embedding) - ) - self.lang_model.get_input_embeddings().weight = wrap( - wrap(self.lang_model.get_input_embeddings().weight) - ) - else: - self.lang_model.set_input_embeddings( - wrap(wrap(self.lang_model.get_input_embeddings())) - ) - - if hasattr(self.lang_model.get_output_embeddings(), "additional_fc"): - # wrap additional_fc and original embedding separately, this is the case when using decoupled linear layer - self.lang_model.get_output_embeddings().additional_fc = wrap( - wrap(self.lang_model.get_output_embeddings().additional_fc) - ) - self.lang_model.get_output_embeddings().weight = wrap( - wrap(self.lang_model.get_output_embeddings().weight) - ) - if self.lang_model.get_output_embeddings().bias is not None: - self.lang_model.get_output_embeddings().bias = wrap( - wrap(self.lang_model.get_output_embeddings().bias) - ) - else: - self.lang_model.set_output_embeddings( - wrap(wrap(self.lang_model.get_output_embeddings())) - ) - - self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen - - # manually move non-FSDP managed parameters to device_id - # these 
are all in lang_model - apply_with_stopping_condition( - module=self.lang_model, - apply_fn=lambda m: m.to(device_id), - apply_condition=lambda m: len(list(m.children())) == 0, - stopping_condition=lambda m: isinstance(m, FSDP), - ) - - # set up clip_grad_norm_ function - def clip_grad_norm_(max_norm): - self.perceiver.clip_grad_norm_(max_norm) - for layer in self.lang_model.gated_cross_attn_layers: - if layer is not None: - layer.clip_grad_norm_(max_norm) - self.lang_model.get_input_embeddings().clip_grad_norm_(max_norm) - - self.clip_grad_norm_ = clip_grad_norm_ diff --git a/open_flamingo/src/kosmos.py b/open_flamingo/src/kosmos.py index 1ada1396..ef538291 100644 --- a/open_flamingo/src/kosmos.py +++ b/open_flamingo/src/kosmos.py @@ -50,6 +50,3 @@ def _should_apply_weight_decay(self, parameter_name): Kosmos applies 0.01 weight deacy to everything """ return True - - def wrap_fsdp(self, wrapper_kwargs, device_id): - raise NotImplementedError diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py index d38c8037..4f4aee21 100644 --- a/open_flamingo/src/vlm.py +++ b/open_flamingo/src/vlm.py @@ -2,7 +2,7 @@ from einops import rearrange from torch import nn from typing import List, Optional, Tuple, Union -from .utils import extend_instance, stack_with_padding, num_params +from .utils import extend_instance, stack_with_padding, num_params, getattr_recursive from .cross_attn_lm import CrossAttentionMixin from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast @@ -325,8 +325,12 @@ def set_special_token_ids(self, string_to_ids): setattr(self, f"{att_name}_id", token_id) setattr(self.lang_model, f"{att_name}_id", token_id) - def wrap_fsdp(self, wrapper_kwargs, device_id): - raise NotImplementedError + def get_fsdp_unsharded_params(self): + """ + Returns a list of parameters that should not be sharded by FSDP. + These will occupy GPU memory, but we'll save on communication costs. + """ + return [] def init_gradient_checkpointing(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -339,7 +343,6 @@ def init_gradient_checkpointing(self): non_reentrant_wrapper = partial( checkpoint_wrapper, - offload_to_cpu=True, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) apply_activation_checkpointing( @@ -439,6 +442,29 @@ def _post_forward_hook(self): # clear the conditioned layers self.lang_model.clear_conditioned_layers() + def get_fsdp_lambda_fn(self): + """ + Returns the lambda function used to decide how to perform FSDP wrapping. 
+ """ + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + ) + from .helpers import GatedCrossAttentionBlock + + original_decoder_block_class = self.lang_model.old_decoder_blocks[0].__class__ + + def lambda_fn(module: nn.Module): + if isinstance(module, CheckpointWrapper): + return False + if module is self.vision_tokenizer: + return True + if isinstance(module, GatedCrossAttentionBlock): + return True + if isinstance(module, original_decoder_block_class): + return True + + return lambda_fn + @property def num_params_per_module(self): """Print the number of parameters per module in the model""" @@ -480,6 +506,7 @@ def __init__( lang_model: nn.Module, initial_tokenizer_len: int, pad_token_id: int, + decoder_layers_attr_name: str = None, gradient_checkpointing: bool = False, ): super().__init__( @@ -491,6 +518,7 @@ def __init__( gradient_checkpointing=gradient_checkpointing, ) self.lang_model._use_gradient_checkpointing = gradient_checkpointing + self.decoder_layers_attr_name = decoder_layers_attr_name assert ( self.vis_embedding_dim == self.lang_embedding_dim ), "To place visual tokens direclty in the language stream, the visual and language tokens need to be the same dim." @@ -652,6 +680,28 @@ def _postprocess_outputs_from_forward( def _post_forward_hook(self): pass + def get_fsdp_lambda_fn(self): + """ + Returns the lambda function used to decide how to perform FSDP wrapping. + """ + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointWrapper, + ) + + decoder_block_class = getattr_recursive( + self.lang_model, self.decoder_layers_attr_name + )[0].__class__ + + def lambda_fn(module: nn.Module): + if isinstance(module, CheckpointWrapper): + return False + if module is self.vision_tokenizer: + return True + if isinstance(module, decoder_block_class): + return True + + return lambda_fn + @property def num_params_per_module(self): """Print the number of parameters per module in the model""" diff --git a/open_flamingo/train/README.md b/open_flamingo/train/README.md index ae086897..3e95e3e4 100644 --- a/open_flamingo/train/README.md +++ b/open_flamingo/train/README.md @@ -54,7 +54,10 @@ torchrun --nnodes=1 --nproc_per_node=4 train.py \ ## Distributed training By default, `train.py` uses Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training. -To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag. +To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), make sure to use Pytorch Nightly (> 2.0.1), and use the `--fsdp` flag. +To use [DeepSpeed](https://github.com/microsoft/DeepSpeed), use the `--deepspeed` flag. +(Note that you should use *either* FSDP or Deepspeed, not both.) + Some notes on FSDP: diff --git a/open_flamingo/train/data.py b/open_flamingo/train/data.py index 2066625f..92c231be 100644 --- a/open_flamingo/train/data.py +++ b/open_flamingo/train/data.py @@ -58,16 +58,18 @@ def filter_no_caption_or_no_image(sample): "png" in sample or "jpg" in sample or "jpeg" in sample ) + def preprocess_laion_image(sample, image_processor): """ - Preprocess image for LAION. + Preprocess image for LAION. Applied to a batch of images. """ sample = preprocess_image(sample, image_processor) return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1) + def preprocess_laion_text(sample, tokenizer, max_tokens=32): """ - Preprocess text for LAION. + Preprocess text for LAION. 
Applied to a batch of captions. Captions are truncated to 32 tokens by default. """ tokenizer.padding_side = "right" @@ -159,6 +161,7 @@ def preprocess_interleaved( """ Preprocess an interleaved image-text sequence, either by calling preprocess_gpt_interleaved (if the sequence is ChatGPT-generated) or by preprocessing in this function (if the sequences is from MMC4). + Applied to a single sequence. """ info = json.loads(sample[0]) if "is_gpt" in info: @@ -244,7 +247,7 @@ def preprocess_interleaved( elif ( num_images == 1 and random.random() <= 0.5 ): # 50% chance of keeping single image samples - raise ValueError("Only one image in sample") + raise ValueError("Only one images in sample") # avoid the situation where there's one token and it's at the end if ( @@ -259,7 +262,7 @@ def preprocess_interleaved( ) return ( - rearrange(images_tensors, "b (t f) c h w -> b t f c h w", f=1), + rearrange(images_tensors, "(t f) c h w -> t f c h w", f=1), (text_tensor["input_ids"], text_tensor["attention_mask"]), ) @@ -326,11 +329,17 @@ def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False): ] ) + def zip_text(batch): + """Unpack from [(input_ids, attention_mask), ...] to (input_ids, attention_mask)""" + input_ids, attention_mask = tuple(zip(*batch)) + return torch.stack(input_ids).squeeze(1), torch.stack(attention_mask).squeeze(1) + pipeline.extend( [ wds.to_tuple("json", handler=log_and_continue), wds.map(preprocess_fn, handler=log_and_continue), wds.batched(args.batch_size_mmc4, partial=False), + wds.map_tuple(lambda x: x, zip_text, handler=log_and_continue), ] ) @@ -399,7 +408,9 @@ def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False): pipeline = [wds.SimpleShardList(input_shards)] # create two preprocess functions that take in the passed in image_processor and tokenizer - preprocess_image_fn = functools.partial(preprocess_laion_image, image_processor=image_processor) + preprocess_image_fn = functools.partial( + preprocess_laion_image, image_processor=image_processor + ) preprocess_text_fn = functools.partial(preprocess_laion_text, tokenizer=tokenizer) # at this point we have an iterator over all the shards @@ -490,4 +501,4 @@ def get_data(args, image_processor, tokenizer, dataset_type, epoch=0): args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer ) else: - raise ValueError(f"Unsupported dataset: {dataset_type}") \ No newline at end of file + raise ValueError(f"Unsupported dataset: {dataset_type}") diff --git a/open_flamingo/train/data_utils.py b/open_flamingo/train/data_utils.py index 0ed961ed..d20ee3b0 100644 --- a/open_flamingo/train/data_utils.py +++ b/open_flamingo/train/data_utils.py @@ -45,11 +45,11 @@ class DataInfo: """ DataInfo is a dataclass that holds information about a dataset. """ + name: str dataloader: DataLoader batch_size: int loss_multiplier: int - max_tokens: int sampler: DistributedSampler = None shared_epoch: SharedEpoch = None @@ -93,7 +93,8 @@ def get_dataset_size(shards): def log_and_continue(exn): """Call in an exception handler to ignore any exception, issue a warning, and continue.""" - logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + if "images in sample" not in repr(exn): + logging.warning(f"Handling webdataset error ({repr(exn)}). 
Ignoring.") return True @@ -228,4 +229,4 @@ def __iter__(self): seed = self.worker_seed() + epoch self.rng.seed(seed) for _ in range(self.nshards): - yield dict(url=self.rng.choice(self.urls)) \ No newline at end of file + yield dict(url=self.rng.choice(self.urls)) diff --git a/open_flamingo/train/distributed.py b/open_flamingo/train/distributed.py index ce7cd485..293a74da 100644 --- a/open_flamingo/train/distributed.py +++ b/open_flamingo/train/distributed.py @@ -1,10 +1,14 @@ """ Util functions for distributed training, FSDP, and Deepspeed. -Credit: https://github.com/mlfoundations/open_clip/blob/main/src/training/distributed.py """ import os import torch +from data import SUPPORTED_DATASETS + +################################## +# SLURM setup; Credit: open_clip # +################################## try: import horovod.torch as hvd @@ -132,6 +136,11 @@ def init_distributed_device(args): return device +##################################### +# FSDP and Deepspeed util functions # +##################################### + + def get_fsdp_mixed_precision_policy( precision: str, reduce_param_precision=False, @@ -176,30 +185,14 @@ def get_fsdp_config( reduce_buffer_precision=True, ) - # init process groups - from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups - from torch.distributed.distributed_c10d import _get_default_group - - if args.fsdp_sharding_strategy == "hybrid": - intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( - _get_default_group() - ) - args.my_group = intra_node_group # for optimizer saving - process_group = (intra_node_group, inter_node_group) # for FSDP init - else: - args.my_group = None # for optimizer saving - process_group = None # for FSDP init - # init FSDP from torch.distributed.fsdp import ( - CPUOffload, ShardingStrategy, BackwardPrefetch, ) return dict( - process_group=process_group, - cpu_offload=CPUOffload(offload_params=False), + cpu_offload=None, device_id=device_id, sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks sharding_strategy=ShardingStrategy.FULL_SHARD @@ -213,9 +206,33 @@ def get_fsdp_config( ) +def get_fsdp_checkpoint_config(args): + """ + Return kwargs for FSDP checkpointing. + """ + from torch.distributed.fsdp import ( + FullStateDictConfig, + StateDictType, + ) + from torch.distributed.fsdp.api import FullOptimStateDictConfig + + # to avoid GPU OOM when loading/saving ckpts, load/save on CPU + # this is slow + return dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + optim_state_dict_config=FullOptimStateDictConfig( + rank0_only=True, offload_to_cpu=True + ), + ) + + def get_deepspeed_config( args, ): + """ + Return kwargs for Deepspeed config. 
+ """ zero_opt_dict = { "stage": args.deepspeed_stage, "overlap_comm": True, @@ -227,11 +244,15 @@ def get_deepspeed_config( "stage3_prefetch_bucket_size": 3e7, "memory_efficient_linear": False, } + # sum all the args that start with batch_size_ to get the total batch size + total_batch_size = sum( + [getattr(args, arg) for arg in vars(args) if arg.startswith("batch_size_")] + ) ds_config = { - "train_batch_size": (args.batch_size_mmc4 + args.batch_size_laion) + "train_batch_size": total_batch_size * args.world_size * args.gradient_accumulation_steps, - "train_micro_batch_size_per_gpu": (args.batch_size_mmc4 + args.batch_size_laion) + "train_micro_batch_size_per_gpu": total_batch_size * args.gradient_accumulation_steps, "steps_per_print": args.logging_steps, "zero_optimization": zero_opt_dict, @@ -249,4 +270,3 @@ def get_deepspeed_config( raise ValueError("amp not supported with DeepSpeed") return ds_config - diff --git a/open_flamingo/train/losses.py b/open_flamingo/train/losses.py index ffc859a5..0133fc90 100644 --- a/open_flamingo/train/losses.py +++ b/open_flamingo/train/losses.py @@ -1,6 +1,17 @@ from open_flamingo.src.vlm import VLM import torch +SUPPORTED_LOSSES = ["next_token_prediction"] + + +def get_loss_fn(loss_name): + if loss_name == "next_token_prediction": + return NextTokenPrediction() + else: + raise ValueError( + f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}" + ) + class Loss: @property @@ -48,9 +59,10 @@ def __call__( labels = input_ids.clone() labels[labels == tokenizer.pad_token_id] = -100 labels[labels == tokenizer.eos_token] = -100 - labels[ - torch.isin(labels, torch.Tensor(unwrap_model(model).special_token_ids)) - ] = -100 + special_token_ids = torch.Tensor(unwrap_model(model).special_token_ids).to( + labels.device + ) + labels[torch.isin(labels, special_token_ids)] = -100 labels = labels.to(input_ids.device) # call forward diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index e3d7c5a1..dd925c3d 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -4,8 +4,10 @@ import torch import wandb import deepspeed +import functools from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy from open_flamingo import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES from data import get_data, SUPPORTED_DATASETS @@ -13,6 +15,7 @@ init_distributed_device, world_info_from_env, get_fsdp_config, + get_fsdp_checkpoint_config, get_deepspeed_config, ) from train_utils import ( @@ -25,7 +28,8 @@ save_deepspeed_checkpoint, ) from losses import ( - NextTokenPrediction, + SUPPORTED_LOSSES, + get_loss_fn, ) from transformers import ( get_constant_schedule_with_warmup, @@ -57,6 +61,9 @@ def main(): ) # training args + parser.add_argument( + "--loss", type=str, choices=SUPPORTED_LOSSES, default="next_token_prediction" + ) parser.add_argument( "--run_name", type=str, @@ -122,6 +129,7 @@ def main(): parser.add_argument( f"--{dataset_name}_shards", type=str, + default=None, help="Should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar. 
If None, we will not train on this dataset.", ) parser.add_argument("--workers", type=int, default=1) @@ -212,17 +220,22 @@ def main(): args = parser.parse_args() + # Parse which datasets to train on and which to exclude + datasets_to_train_on = [] + for dataset_name in SUPPORTED_DATASETS: + if getattr(args, f"{dataset_name}_shards") is None: + setattr(args, f"train_num_samples_{dataset_name}", 0) + setattr(args, f"batch_size_{dataset_name}", 0) + else: + datasets_to_train_on.append(dataset_name) + assert len(datasets_to_train_on) > 0, "Must train on at least one dataset" + # Validate args for dataset_name in SUPPORTED_DATASETS: shards_path = getattr(args, f"{dataset_name}_shards") if shards_path is not None and shards_path.startswith("s3"): args.laion_shards = f"pipe:aws s3 cp {args.laion_shards} -" - datasets_to_train_on = [ - dataset_name - for dataset_name in SUPPORTED_DATASETS - if getattr(args, f"{dataset_name}_shards") is not None - ] for i in range(len(datasets_to_train_on) - 1): assert getattr(args, f"train_num_samples_{datasets_to_train_on[i]}") // getattr( args, f"batch_size_{datasets_to_train_on[i]}" @@ -240,7 +253,7 @@ def main(): if args.fsdp: print( - "Warning: FSDP is experimental and not fully supported. Preference should be given to Deepspeed." + "Warning: FSDP is experimental and not fully tested. Preference should be given to Deepspeed." ) assert ( "dev" in torch.__version__ and torch.__version__ > "2.0.1" @@ -248,6 +261,8 @@ def main(): # Set up distributed training args.local_rank, args.rank, args.world_size = world_info_from_env() + if args.rank == 0: + print(f"Initializing distributed training with {args.world_size} GPUs.") if args.offline: os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" @@ -273,15 +288,15 @@ def main(): args.lm_path, args.tokenizer_path if args.tokenizer_path else args.lm_path, model_family=args.model_family, - untie_embeddings=args.fsdp, # untie embeddings for FSDP + untie_embeddings=False, # untie embeddings for FSDP use_local_files=args.offline, gradient_checkpointing=args.gradient_checkpointing, + verbose=(args.rank == 0), **additional_kwargs, ) random_seed(args.seed, args.rank) # Initialize wandb logging - print(f"Start running training on rank {args.rank}.") if args.rank == 0 and args.report_to_wandb: wandb.init( project=args.wandb_project, @@ -291,33 +306,47 @@ def main(): ) # Load model checkpoint (on CPU) + if args.fsdp: + args.fsdp_checkpoint_config = get_fsdp_checkpoint_config(args) + + # if args do not specify a checkpoint to resume from, resume from most recent checkpoint if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None: - # if args do not specify a checkpoint to resume from, resume from most recent checkpoint args.resume_from_checkpoint = find_most_recent_checkpoint(args) + if ( args.resume_from_checkpoint is not None and not args.deepspeed ): # deepspeed handles checkpoint loading resume_from_epoch, checkpoint = load_checkpoint(args, model) - - # Initialize FSDP / DDP, and ensure the model is on GPU - print(f"Initializing distributed training with {args.world_size} GPUs.") - if args.fsdp: - model.wrap_fsdp( - get_fsdp_config(args, device_id), device_id - ) # moves model to device_id - ddp_model = model - elif not args.deepspeed: - model = model.to(device_id) - ddp_model = DDP(model, device_ids=[device_id]) + else: + resume_from_epoch = 0 # Initialize gradient checkpointing if args.gradient_checkpointing: if args.deepspeed: raise ValueError( - "gradient checkpointing currently 
not supported with deepspeed" + "Gradient checkpointing currently not supported with deepspeed" ) model.init_gradient_checkpointing() + # Initialize FSDP / DDP, and ensure the model is on GPU + # Deepspeed is initialized later + if args.fsdp: + auto_wrap_policy = functools.partial( + lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn() + ) + wrapper_kwargs = get_fsdp_config(args, device_id) + # to save on communication, we may choose to not shard some params + unsharded_params = model.get_fsdp_unsharded_params() + for p in unsharded_params: + p = p.to(device_id) + wrapper_kwargs["ignored_states"] = unsharded_params + distributed_model = FSDP( + model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs + ) + elif not args.deepspeed: + model = model.to(device_id) + distributed_model = DDP(model, device_ids=[device_id]) + # Initialize optimizer params_with_wd, params_without_wd = model.group_params_by_weight_decay() optimizer = torch.optim.AdamW( @@ -332,7 +361,13 @@ def main(): if args.resume_from_checkpoint is not None and not args.deepspeed: osd = checkpoint["optimizer_state_dict"] if args.fsdp: - osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer) + FSDP.set_state_dict_type( + distributed_model, + **args.fsdp_checkpoint_config, + ) + osd = FSDP.optim_state_dict_to_load( + model=distributed_model, optim=optimizer, optim_state_dict=osd + ) optimizer.load_state_dict(osd) # Initialize datasets @@ -370,7 +405,7 @@ def main(): lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) if args.deepspeed: - ddp_model, optimizer, _, lr_scheduler = deepspeed.initialize( + distributed_model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, optimizer=optimizer, args=args, @@ -379,18 +414,22 @@ def main(): dist_init_required=True, ) if args.resume_from_checkpoint is not None: - resume_from_epoch = load_deepspeed_checkpoint(args, ddp_model) + resume_from_epoch = load_deepspeed_checkpoint(args, distributed_model) + + # Initialize the loss fn + loss_fn = get_loss_fn(args.loss) # Start training! + print(f"Start running training on rank {args.rank}.") for epoch in range(resume_from_epoch, args.num_epochs): for dataset in datasets: dataset.set_epoch(epoch) train_one_epoch( args=args, - model=ddp_model, + model=distributed_model, epoch=epoch, datasets=datasets, - compute_loss_fn=NextTokenPrediction(), + compute_loss_fn=loss_fn, tokenizer=tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler, @@ -399,9 +438,9 @@ def main(): ) if args.deepspeed: - save_deepspeed_checkpoint(ddp_model, epoch, args) + save_deepspeed_checkpoint(distributed_model, epoch, args) else: - save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args) + save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args) if __name__ == "__main__": diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index 10c0bc8a..513b877d 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -3,11 +3,6 @@ import torch from tqdm import tqdm from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp import ( - FullStateDictConfig, - StateDictType, -) -from torch.distributed.fsdp.api import FullOptimStateDictConfig import os import shutil import wandb @@ -32,6 +27,18 @@ def train_one_epoch( """ Helper function for running one epoch of training. Handles logging, calling forward, backward, gradient clipping, and optimizer step. 
+ Args: + args (argparse.Namespace): arguments from command line + model: DDP / FSDP / Deepspeed wrapped model + epoch (int): epoch number + datasets (list): list of DataInfos, one for each dataset, to train on + compute_loss_fn (callable): function that given the model and inputs, calls forward + and returns a loss + tokenizer: tokenizer for the language model + optimizer: optimizer to step + lr_scheduler: learning rate scheduler + device_id (int): GPU device ID for this rank + wandb: wandb object for logging """ # calculate the number of steps in an epoch num_batches_per_epoch = datasets[0].dataloader.num_batches @@ -60,11 +67,13 @@ def train_one_epoch( # call compute_loss_fn on each dataset; call backward before continuing losses_to_log = {} batch_metadata_to_log = {} - for dataset_ix, batch in enumerate(batches): + for dataset_ix, (images, (input_ids, attention_mask)) in enumerate(batches): + print(">> Dataset: ", datasets[dataset_ix].name, "Step: ", step_num) + # unpack the batch and move to device - images = batch[0].to(device_id, dtype=cast_dtype, non_blocking=True) - input_ids = batch[1][0].to(device_id, non_blocking=True) - attention_mask = batch[1][1].to(device_id, non_blocking=True) + images = images.to(device_id, dtype=cast_dtype, non_blocking=True) + input_ids = input_ids.to(device_id, non_blocking=True) + attention_mask = attention_mask.to(device_id, non_blocking=True) # save some metadata for logging batch_metadata_to_log[ @@ -151,14 +160,14 @@ def train_one_epoch( # Log loss to console if ((step_num + 1) % args.logging_steps == 0) and args.rank == 0: print( - f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete." - + "//".join([f"{k}: {v:.3f}" for k, v in losses_to_log]) + f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Losses: " + + "// ".join([f"{k}: {v:.3f}" for k, v in losses_to_log.items()]) ) def get_cast_dtype(precision: str): """ - Returns the dtype to cast inputs to for a given precision. + Parses the precision argument and returns the dtype to cast inputs to. """ cast_dtype = None if precision == "bf16": @@ -169,6 +178,9 @@ def get_cast_dtype(precision: str): def get_autocast(precision, cache_enabled=True): + """ + Parses the precision argument and returns an autocast context manager. + """ if precision == "amp": return torch.cuda.amp.autocast(cache_enabled=cache_enabled) elif precision == "amp_bfloat16" or precision == "amp_bf16": @@ -178,16 +190,20 @@ def get_autocast(precision, cache_enabled=True): ) else: return suppress - + + def random_seed(seed=42, rank=0): + """Seed everything""" torch.manual_seed(seed + rank) np.random.seed(seed + rank) random.seed(seed + rank) + ################################ # Helper functions for logging # ################################ + class AverageMeter(object): """Computes and stores the average and current value""" @@ -206,12 +222,16 @@ def update(self, val, n=1): self.count += n self.avg = self.sum / self.count + def compute_throughput( args, datasets, batch_metadata, step_time_m, ): + """ + Computes throughput metrics for logging, including samples per second and tokens per second. + """ log = {} for dataset in datasets: log[f"{dataset.name}_samples_per_second_per_gpu"] = ( @@ -244,6 +264,7 @@ def compute_throughput( # Helper functions for checkpoint loading / saving # #################################################### + def find_most_recent_checkpoint(args): """ Returns the path of the most recent checkpoint for a given run name. 
@@ -265,22 +286,29 @@ def find_most_recent_checkpoint(args): print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.") return resume_from_checkpoint + def load_checkpoint(args, model): - """Loads a (non-Deepspeed) checkpoint and returns the checkpoint + epoch to resume from.""" + """ + Loads a (non-Deepspeed) checkpoint into the model and returns the checkpoint + epoch to resume from. + Does not load the optimizer or learning rate checkpoints, but these are included in the returned checkpoint dict. + """ if args.rank == 0: print(f"Loading checkpoint from {args.resume_from_checkpoint}") checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") msd = checkpoint.pop("model_state_dict") msd = {k.replace("module.", ""): v for k, v in msd.items()} resume_from_epoch = checkpoint["epoch"] + 1 - - # for fsdp, only one rank needs to load the state dict - if not args.fsdp or args.rank == 0: - model.load_state_dict(msd, False) + if args.fsdp: + FSDP.set_state_dict_type( + model, + **args.fsdp_checkpoint_config, + ) + model.load_state_dict(msd, False) return resume_from_epoch, checkpoint + def load_deepspeed_checkpoint(args, ddp_model): - """Loads a deepspeed checkpoint and returns the epoch to resume from.""" + """Loads a deepspeed checkpoint and returns the epoch to resume from.""" if args.rank == 0: print(f"Loading checkpoint from {args.resume_from_checkpoint}") # We will not pass in a 'tag' and instead rely on 'latest' file in the checkpoint directory @@ -302,34 +330,25 @@ def filter_state_dict_to_trainable(model, state_dict): This is because we need the new <|endofchunk|> tokens to be consistent across initializations. """ - for ( - name, - p, - ) in model.named_parameters(): # won't work for fsdp + use_orig_params=False + # first, remove frozen params + for name, p in model.named_parameters(): if "fsdp" in name: continue - if "embed" in name or isinstance(p, torch.nn.Embedding): - continue - if not p.requires_grad: + if not p.requires_grad or to_delete(name): name = name.replace("._checkpoint_wrapped_module", "") if name in state_dict: del state_dict[name] else: print(f"WARNING: filtering but {name} not in state_dict") - - # also remove the keys in state_dict generated from - # lang_encoder.old_decoder_blocks and lang_encoder.gated_cross_attn_layers - # because these are already saved in lang_encoder.model... - to_delete = [ - n - for n in state_dict.keys() - if ("lang_encoder.old_decoder_blocks" in n) - or ("lang_encoder.gated_cross_attn_layers" in n) - or ("vision_encoder" in n) - ] - for name in to_delete: - del state_dict[name] - return state_dict + # second, remove additional duplicate params + duplicate = lambda k: ( + "lang_model.old_decoder_blocks" in k + or "lang_model.gated_cross_attn_layers" in k + ) + filtered_dict = { + key: value for key, value in state_dict.items() if not duplicate(key) + } + return filtered_dict def save_checkpoint(model, optimizer, lr_scheduler, epoch, args): @@ -339,12 +358,10 @@ def save_checkpoint(model, optimizer, lr_scheduler, epoch, args): if args.fsdp: FSDP.set_state_dict_type( model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(rank0_only=True, offload_to_cpu=True), - FullOptimStateDictConfig(rank0_only=True), + **args.fsdp_checkpoint_config, ) model_state = model.state_dict() - optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group) + optim_state = FSDP.optim_state_dict(model, optimizer) else: model_state = model.state_dict() optim_state = optimizer.state_dict()
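
For readers reviewing the new FSDP path in this patch (get_fsdp_lambda_fn in vlm.py feeding lambda_auto_wrap_policy in train.py), the following is a minimal sketch of how such a lambda function drives the wrapping decision. ToyDecoderBlock and ToyVLM are invented stand-ins for illustration only, not classes from this repo; in train.py the resulting policy is passed to FSDP(model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs).

import functools
import torch.nn as nn
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


class ToyDecoderBlock(nn.Module):
    # stand-in for the language model's decoder block class
    def __init__(self, dim=32):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)


class ToyVLM(nn.Module):
    # stand-in for VLMWithLanguageStream; only the pieces relevant to wrapping
    def __init__(self, dim=32, n_layers=2):
        super().__init__()
        self.vision_tokenizer = nn.Linear(dim, dim)
        self.decoder_blocks = nn.ModuleList(ToyDecoderBlock(dim) for _ in range(n_layers))

    def get_fsdp_lambda_fn(self):
        # mirrors the idea in VLMWithLanguageStream.get_fsdp_lambda_fn: give the vision
        # tokenizer and each decoder block their own FSDP unit; leave the rest in the root unit
        def lambda_fn(module: nn.Module):
            if module is self.vision_tokenizer:
                return True
            if isinstance(module, ToyDecoderBlock):
                return True
            return False

        return lambda_fn


model = ToyVLM()
lambda_fn = model.get_fsdp_lambda_fn()
auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn)
# train.py would now call FSDP(model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs);
# here we only print which submodules the policy selects for their own FSDP unit.
for name, module in model.named_modules():
    if lambda_fn(module):
        print(f"would wrap: {name or '<root>'} ({module.__class__.__name__})")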
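
The zip_text collation step added to get_mmc4_dataset can be hard to picture from the diff alone. Below is a toy round-trip, assuming each preprocess_interleaved output carries input_ids and attention_mask of shape (1, seq_len), as a Hugging Face tokenizer with return_tensors="pt" would produce; the tensors here are synthetic.

import torch


def zip_text(batch):
    # same helper as in get_mmc4_dataset: [(input_ids, attention_mask), ...]
    # -> (stacked input_ids, stacked attention_mask), each of shape (B, seq_len)
    input_ids, attention_mask = tuple(zip(*batch))
    return torch.stack(input_ids).squeeze(1), torch.stack(attention_mask).squeeze(1)


# toy "tokenized" samples: each element mimics one preprocess_interleaved output
seq_len = 8
batch = [
    (torch.full((1, seq_len), i, dtype=torch.long), torch.ones(1, seq_len, dtype=torch.long))
    for i in range(3)
]
input_ids, attention_mask = zip_text(batch)
print(input_ids.shape, attention_mask.shape)  # torch.Size([3, 8]) torch.Size([3, 8])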
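
The losses.py change moves the special-token id tensor onto the labels' device before calling torch.isin. A small self-contained illustration of that masking step, using made-up token ids rather than real tokenizer output:

import torch

# toy vocabulary ids; in losses.py these come from the tokenizer and the unwrapped model
pad_token_id = 0
special_token_ids = torch.tensor([32001, 32002])  # e.g. <image>, <|endofchunk|>

input_ids = torch.tensor([[5, 32001, 7, pad_token_id],
                          [32002, 9, 11, pad_token_id]])
labels = input_ids.clone()
labels[labels == pad_token_id] = -100
# as in the patched NextTokenPrediction: move the id tensor to the labels' device
# before torch.isin so the comparison works on GPU as well as CPU
labels[torch.isin(labels, special_token_ids.to(labels.device))] = -100
print(labels)
# tensor([[   5, -100,    7, -100],
#         [-100,    9,   11, -100]])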