many fixes + rewrite FSDP for torch nightly
i-gao committed Sep 16, 2023
1 parent 11ab894 commit cd4f3aa
Showing 12 changed files with 279 additions and 235 deletions.
4 changes: 1 addition & 3 deletions open_flamingo/src/blip.py
@@ -2,6 +2,7 @@
from .helpers import QFormerWithProjection
from .vlm import VLMWithLanguageStream


class BLIP(VLMWithLanguageStream):
def __init__(
self,
@@ -58,6 +59,3 @@ def set_trainable(self):
def _should_apply_weight_decay(self, parameter_name):
"""BLIP applies 0.05 weight decay to everything"""
return True

def wrap_fsdp(self, wrapper_kwargs, device_id):
raise NotImplementedError
20 changes: 13 additions & 7 deletions open_flamingo/src/factory.py
@@ -23,6 +23,7 @@ def create_model_and_transforms(
cache_dir: Optional[str] = None,
gradient_checkpointing: bool = False,
untie_embeddings: bool = False,
verbose: bool = True,
**model_kwargs,
):
"""
@@ -40,6 +41,7 @@
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
untie_embeddings (bool, optional): whether to untie the input and output embeddings of the language model. Defaults to False.
verbose (bool, optional): whether to print model info. Defaults to True.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
@@ -79,11 +81,12 @@
)
check_embedding_fns(lang_model)
if untie_embeddings:
print("Untying language model embeddings...")
lang_model.get_output_embeddings().weight = nn.Parameter(
lang_model.get_output_embeddings().weight.clone()
)
lang_model.config.update({"tie_word_embeddings": False})

# vocab sizes: note that lang_model.config.vocab_size is not necessarily = len(text_tokenizer)
# the current input_embedding / output_embedding weights probably use lang_model.config.vocab_size
# but the tokenizer will assign additional ids based on len(text_tokenizer)
@@ -139,13 +142,16 @@
}
)

# freeze appropraite parameters
# freeze appropriate parameters
model.set_trainable()
print(
f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
)
print(f"==========\n{model.num_trainable_params_per_module}")
print(f"==========\n{model.num_params_per_module}\n==========")

# log model info
if verbose:
print(
f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
)
print(f"==========\n{model.num_trainable_params_per_module}")
print(f"==========\n{model.num_params_per_module}\n==========")
return model, image_processor, text_tokenizer


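A toy sketch (not repo code) of why the `clone()` in the `untie_embeddings` branch above actually breaks weight tying: tied input/output embeddings share one storage, and cloning allocates a fresh tensor for the output head. The vocab and hidden sizes below are made up for illustration.

```python
import torch
from torch import nn

vocab_size, hidden_dim = 10, 4
input_emb = nn.Embedding(vocab_size, hidden_dim)
output_emb = nn.Linear(hidden_dim, vocab_size, bias=False)

output_emb.weight = input_emb.weight  # tied: both parameters point at the same storage
assert output_emb.weight.data_ptr() == input_emb.weight.data_ptr()

# what untie_embeddings does: clone into a new Parameter, so later updates diverge
output_emb.weight = nn.Parameter(output_emb.weight.clone())
assert output_emb.weight.data_ptr() != input_emb.weight.data_ptr()
```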
112 changes: 1 addition & 111 deletions open_flamingo/src/flamingo.py
@@ -1,5 +1,5 @@
from torch import nn
from .helpers import PerceiverResampler
from .helpers import PerceiverResampler, GatedCrossAttentionBlock
from .vlm import VLMWithCrossAttention


@@ -61,113 +61,3 @@ def _should_apply_weight_decay(self, parameter_name):
Flamingo applies 0.1 weight decay to cross attention parameters
"""
return "gated_cross_attn" in parameter_name

def wrap_fsdp(self, wrapper_kwargs, device_id):
"""
Manually wraps submodules for FSDP and move other parameters to device_id.
Why manually wrap?
- all parameters within the FSDP wrapper must have the same requires_grad.
We have a mix of frozen and unfrozen parameters.
- model.vision_encoder.visual needs to be individually wrapped or encode_vision_x errors
See: https://github.com/pytorch/pytorch/issues/82461#issuecomment-1269136344
The rough wrapping structure is:
- FlamingoModel
- FSDP(FSDP(vision_encoder))
- FSDP(FSDP(perceiver))
- lang_model
- FSDP(FSDP(input_embeddings))
- CrossAttentionLayers
- FSDP(FSDP(gated_cross_attn_layer))
- FSDP(FSDP(decoder_layer))
- FSDP(FSDP(output_embeddings))
- other parameters
Known issues:
- Our FSDP strategy is not compatible with tied embeddings. If the LM embeddings are tied,
train with DDP or set the --freeze_lm_embeddings flag to true.
- With FSDP + gradient ckpting, one can increase the batch size with seemingly no upper bound.
Although the training curves look okay, we found that downstream performance dramatically
degrades if the batch size is unreasonably large (e.g., 100 MMC4 batch size for OPT-125M).
FAQs about our FSDP wrapping strategy:
Why double wrap?
As of torch==2.0.1, FSDP's _post_forward_hook and _post_backward_hook
only free gathered parameters if the module is NOT FSDP root.
"""
print(
"WARNING: FSDP is not designed for training with a mix of frozen and unfrozen parameters. "
+ "This experimental workaround results in a significant drop in GPU power usage."
)

from torch.distributed.fsdp.wrap import (
enable_wrap,
wrap,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
)
from .utils import apply_with_stopping_condition

# wrap in FSDP
with enable_wrap(wrapper_cls=FSDP, **wrapper_kwargs):
self.perceiver = wrap(wrap(self.perceiver))
self.lang_model.old_decoder_blocks = nn.ModuleList(
wrap(wrap(block)) for block in self.lang_model.old_decoder_blocks
)
self.lang_model.gated_cross_attn_layers = nn.ModuleList(
wrap(wrap(layer)) if layer is not None else None
for layer in self.lang_model.gated_cross_attn_layers
)
self.lang_model.init_flamingo_layers(self._use_gradient_checkpointing)
if hasattr(self.lang_model.get_input_embeddings(), "additional_embedding"):
# wrap additional_embedding and original embedding separately, this is the case when using decoupled embeddings
self.lang_model.get_input_embeddings().additional_embedding = wrap(
wrap(self.lang_model.get_input_embeddings().additional_embedding)
)
self.lang_model.get_input_embeddings().weight = wrap(
wrap(self.lang_model.get_input_embeddings().weight)
)
else:
self.lang_model.set_input_embeddings(
wrap(wrap(self.lang_model.get_input_embeddings()))
)

if hasattr(self.lang_model.get_output_embeddings(), "additional_fc"):
# wrap additional_fc and original embedding separately, this is the case when using decoupled linear layer
self.lang_model.get_output_embeddings().additional_fc = wrap(
wrap(self.lang_model.get_output_embeddings().additional_fc)
)
self.lang_model.get_output_embeddings().weight = wrap(
wrap(self.lang_model.get_output_embeddings().weight)
)
if self.lang_model.get_output_embeddings().bias is not None:
self.lang_model.get_output_embeddings().bias = wrap(
wrap(self.lang_model.get_output_embeddings().bias)
)
else:
self.lang_model.set_output_embeddings(
wrap(wrap(self.lang_model.get_output_embeddings()))
)

self.vision_encoder = wrap(wrap(self.vision_encoder)) # frozen

# manually move non-FSDP managed parameters to device_id
# these are all in lang_model
apply_with_stopping_condition(
module=self.lang_model,
apply_fn=lambda m: m.to(device_id),
apply_condition=lambda m: len(list(m.children())) == 0,
stopping_condition=lambda m: isinstance(m, FSDP),
)

# set up clip_grad_norm_ function
def clip_grad_norm_(max_norm):
self.perceiver.clip_grad_norm_(max_norm)
for layer in self.lang_model.gated_cross_attn_layers:
if layer is not None:
layer.clip_grad_norm_(max_norm)
self.lang_model.get_input_embeddings().clip_grad_norm_(max_norm)

self.clip_grad_norm_ = clip_grad_norm_
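With the manual `wrap_fsdp` removed, wrapping is expected to be driven from the training script via the `get_fsdp_lambda_fn` / `get_fsdp_unsharded_params` hooks added in `vlm.py` below. A minimal sketch of that pattern, assuming a torch nightly (> 2.0.1) FSDP; the constructor arguments here (`use_orig_params`, `sync_module_states`, the sharding strategy) are assumptions for illustration, not taken from this commit.

```python
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def wrap_model_with_fsdp(model, device_id):
    """Sketch: wrap the whole VLM once and let the policy pick the FSDP units."""
    policy = functools.partial(
        lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
    )
    return FSDP(
        model,
        auto_wrap_policy=policy,                       # decoder blocks, gated cross-attn, vision tokenizer
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=device_id,                           # replaces the manual .to(device_id) walk above
        use_orig_params=True,                          # lets frozen and trainable params share an FSDP unit (assumption)
        sync_module_states=True,                       # broadcast rank-0 weights at wrap time
    )
```

Under this scheme, gradient clipping can also go through the root wrapper's `clip_grad_norm_` rather than the per-submodule closure removed here.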
3 changes: 0 additions & 3 deletions open_flamingo/src/kosmos.py
@@ -50,6 +50,3 @@ def _should_apply_weight_decay(self, parameter_name):
Kosmos applies 0.01 weight decay to everything
"""
return True

def wrap_fsdp(self, wrapper_kwargs, device_id):
raise NotImplementedError
58 changes: 54 additions & 4 deletions open_flamingo/src/vlm.py
@@ -2,7 +2,7 @@
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
from .utils import extend_instance, stack_with_padding, num_params
from .utils import extend_instance, stack_with_padding, num_params, getattr_recursive
from .cross_attn_lm import CrossAttentionMixin
from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
@@ -325,8 +325,12 @@ def set_special_token_ids(self, string_to_ids):
setattr(self, f"{att_name}_id", token_id)
setattr(self.lang_model, f"{att_name}_id", token_id)

def wrap_fsdp(self, wrapper_kwargs, device_id):
raise NotImplementedError
def get_fsdp_unsharded_params(self):
"""
Returns a list of parameters that should not be sharded by FSDP.
These will occupy GPU memory, but we'll save on communication costs.
"""
return []

def init_gradient_checkpointing(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -339,7 +343,6 @@ def init_gradient_checkpointing(self):

non_reentrant_wrapper = partial(
checkpoint_wrapper,
offload_to_cpu=True,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
@@ -439,6 +442,29 @@ def _post_forward_hook(self):
# clear the conditioned layers
self.lang_model.clear_conditioned_layers()

def get_fsdp_lambda_fn(self):
"""
Returns the lambda function used to decide how to perform FSDP wrapping.
"""
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
from .helpers import GatedCrossAttentionBlock

original_decoder_block_class = self.lang_model.old_decoder_blocks[0].__class__

def lambda_fn(module: nn.Module):
if isinstance(module, CheckpointWrapper):
return False
if module is self.vision_tokenizer:
return True
if isinstance(module, GatedCrossAttentionBlock):
return True
if isinstance(module, original_decoder_block_class):
return True

return lambda_fn

@property
def num_params_per_module(self):
"""Print the number of parameters per module in the model"""
@@ -480,6 +506,7 @@ def __init__(
lang_model: nn.Module,
initial_tokenizer_len: int,
pad_token_id: int,
decoder_layers_attr_name: str = None,
gradient_checkpointing: bool = False,
):
super().__init__(
@@ -491,6 +518,7 @@
gradient_checkpointing=gradient_checkpointing,
)
self.lang_model._use_gradient_checkpointing = gradient_checkpointing
self.decoder_layers_attr_name = decoder_layers_attr_name
assert (
self.vis_embedding_dim == self.lang_embedding_dim
), "To place visual tokens direclty in the language stream, the visual and language tokens need to be the same dim."
@@ -652,6 +680,28 @@ def _postprocess_outputs_from_forward(
def _post_forward_hook(self):
pass

def get_fsdp_lambda_fn(self):
"""
Returns the lambda function used to decide how to perform FSDP wrapping.
"""
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)

decoder_block_class = getattr_recursive(
self.lang_model, self.decoder_layers_attr_name
)[0].__class__

def lambda_fn(module: nn.Module):
if isinstance(module, CheckpointWrapper):
return False
if module is self.vision_tokenizer:
return True
if isinstance(module, decoder_block_class):
return True

return lambda_fn

@property
def num_params_per_module(self):
"""Print the number of parameters per module in the model"""
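The `lambda_fn` returned by `get_fsdp_lambda_fn` is a plain per-module predicate; on recent torch builds it can be fed to `lambda_auto_wrap_policy`, which FSDP evaluates while traversing the module tree. A self-contained toy run of that mechanism (the `Block` class and sizes are made up for illustration):

```python
import functools

from torch import nn
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


class Block(nn.Module):
    """Stand-in for a decoder block / GatedCrossAttentionBlock."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)


toy = nn.Sequential(Block(), nn.LayerNorm(8), Block())


def lambda_fn(module: nn.Module) -> bool:
    # Same shape as the lambda_fns above: True for modules that should become FSDP units.
    return isinstance(module, Block)


policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn)

# With recurse=False the policy reduces to lambda_fn(module), so only the two Blocks print True.
for name, module in toy.named_modules():
    print(name or "<root>", policy(module=module, recurse=False, nonwrapped_numel=0))
```

`get_fsdp_unsharded_params` is the companion hook for parameters the training script may want to keep replicated on every rank (e.g. via FSDP's ignored-parameter arguments on recent nightlies); it defaults to an empty list here.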
5 changes: 4 additions & 1 deletion open_flamingo/train/README.md
@@ -54,7 +54,10 @@ torchrun --nnodes=1 --nproc_per_node=4 train.py \
## Distributed training

By default, `train.py` uses Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html) for training.
To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), use the `--fsdp` flag.
To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), make sure to use Pytorch Nightly (> 2.0.1), and use the `--fsdp` flag.
To use [DeepSpeed](https://github.com/microsoft/DeepSpeed), use the `--deepspeed` flag.
(Note that you should use *either* FSDP or Deepspeed, not both.)


Some notes on FSDP:

23 changes: 17 additions & 6 deletions open_flamingo/train/data.py
@@ -58,16 +58,18 @@ def filter_no_caption_or_no_image(sample):
"png" in sample or "jpg" in sample or "jpeg" in sample
)


def preprocess_laion_image(sample, image_processor):
"""
Preprocess image for LAION.
Preprocess image for LAION. Applied to a batch of images.
"""
sample = preprocess_image(sample, image_processor)
return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1)


def preprocess_laion_text(sample, tokenizer, max_tokens=32):
"""
Preprocess text for LAION.
Preprocess text for LAION. Applied to a batch of captions.
Captions are truncated to 32 tokens by default.
"""
tokenizer.padding_side = "right"
@@ -159,6 +161,7 @@
"""
Preprocess an interleaved image-text sequence, either by calling preprocess_gpt_interleaved (if the sequence
is ChatGPT-generated) or by preprocessing in this function (if the sequence is from MMC4).
Applied to a single sequence.
"""
info = json.loads(sample[0])
if "is_gpt" in info:
@@ -244,7 +247,7 @@
elif (
num_images == 1 and random.random() <= 0.5
): # 50% chance of keeping single image samples
raise ValueError("Only one image in sample")
raise ValueError("Only one images in sample")

# avoid the situation where there's one <image> token and it's at the end
if (
Expand All @@ -259,7 +262,7 @@ def preprocess_interleaved(
)

return (
rearrange(images_tensors, "b (t f) c h w -> b t f c h w", f=1),
rearrange(images_tensors, "(t f) c h w -> t f c h w", f=1),
(text_tensor["input_ids"], text_tensor["attention_mask"]),
)

@@ -326,11 +329,17 @@ def get_mmc4_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
]
)

def zip_text(batch):
"""Unpack from [(input_ids, attention_mask), ...] to (input_ids, attention_mask)"""
input_ids, attention_mask = tuple(zip(*batch))
return torch.stack(input_ids).squeeze(1), torch.stack(attention_mask).squeeze(1)

pipeline.extend(
[
wds.to_tuple("json", handler=log_and_continue),
wds.map(preprocess_fn, handler=log_and_continue),
wds.batched(args.batch_size_mmc4, partial=False),
wds.map_tuple(lambda x: x, zip_text, handler=log_and_continue),
]
)

@@ -399,7 +408,9 @@ def get_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
pipeline = [wds.SimpleShardList(input_shards)]

# create two preprocess functions that take in the passed in image_processor and tokenizer
preprocess_image_fn = functools.partial(preprocess_laion_image, image_processor=image_processor)
preprocess_image_fn = functools.partial(
preprocess_laion_image, image_processor=image_processor
)
preprocess_text_fn = functools.partial(preprocess_laion_text, tokenizer=tokenizer)

# at this point we have an iterator over all the shards
@@ -490,4 +501,4 @@ def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
)
else:
raise ValueError(f"Unsupported dataset: {dataset_type}")
raise ValueError(f"Unsupported dataset: {dataset_type}")