no need to untie embeddings for fsdp
i-gao committed Sep 16, 2023
1 parent ccfcb0f commit 303e707
Showing 4 changed files with 2 additions and 15 deletions.
8 changes: 0 additions & 8 deletions open_flamingo/src/factory.py
@@ -22,7 +22,6 @@ def create_model_and_transforms(
     decoder_layers_attr_name: str = None,
     cache_dir: Optional[str] = None,
     gradient_checkpointing: bool = False,
-    untie_embeddings: bool = False,
     verbose: bool = True,
     **model_kwargs,
 ):
@@ -40,7 +39,6 @@ def create_model_and_transforms(
         decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
         cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
         gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
-        untie_embeddings (bool, optional): whether to untie the input and output embeddings of the language model. Defaults to False.
         verbose (bool, optional): whether to print model info. Defaults to True.
     Returns:
         Flamingo: Flamingo model from pretrained vision and language encoders
@@ -80,12 +78,6 @@ def create_model_and_transforms(
         cache_dir=cache_dir,
     )
     check_embedding_fns(lang_model)
-    if untie_embeddings:
-        print("Untying language model embeddings...")
-        lang_model.get_output_embeddings().weight = nn.Parameter(
-            lang_model.get_output_embeddings().weight.clone()
-        )
-        lang_model.config.update({"tie_word_embeddings": False})
 
     # init the model
     if decoder_layers_attr_name is None:
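For context, a minimal sketch of what the deleted branch did: Hugging Face causal LMs often share a single weight tensor between the input and output embeddings, and cloning the output weight into a fresh `nn.Parameter` breaks that tie. The `gpt2` checkpoint below is only an illustrative tied-embedding model, not one used by this repo's configs.

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

lang_model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tied: both projections point at the same tensor.
assert lang_model.get_input_embeddings().weight is lang_model.get_output_embeddings().weight

# What the removed code did: clone into a fresh Parameter, then record the untie.
lang_model.get_output_embeddings().weight = nn.Parameter(
    lang_model.get_output_embeddings().weight.clone()
)
lang_model.config.update({"tie_word_embeddings": False})

assert lang_model.get_input_embeddings().weight is not lang_model.get_output_embeddings().weight
```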
2 changes: 2 additions & 0 deletions open_flamingo/src/utils.py
@@ -9,6 +9,7 @@ def extend_instance(obj, mixin):
         base_cls_name, (mixin, base_cls), {}
     ) # mixin needs to go first for our forward() logic to work
 
+
 def hasattr_recursive(obj, att):
     """
     Check if obj has nested attribute
@@ -25,6 +26,7 @@ def hasattr_recursive(obj, att):
     except:
         return False
 
+
 def getattr_recursive(obj, att):
     """
     Return nested attribute of obj
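A hypothetical usage sketch of the two helpers above; the dotted-path semantics are inferred from the "nested attribute" docstrings, and the toy classes are made up for illustration.

```python
from open_flamingo.src.utils import getattr_recursive, hasattr_recursive

class Inner:
    value = 42

class Outer:
    inner = Inner()

obj = Outer()
assert hasattr_recursive(obj, "inner.value")        # walks "inner", then "value"
assert not hasattr_recursive(obj, "inner.missing")  # no such nested attribute
assert getattr_recursive(obj, "inner.value") == 42  # equivalent to obj.inner.value
```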
6 changes: 0 additions & 6 deletions open_flamingo/train/README.md
@@ -57,10 +57,4 @@ By default, `train.py` uses Pytorch's [DistributedDataParallel](https://pytorch.
 To use [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html), make sure to use Pytorch Nightly (> 2.0.1), and use the `--fsdp` flag.
 To use [DeepSpeed](https://github.com/microsoft/DeepSpeed), use the `--deepspeed` flag.
 (Note that you should use *either* FSDP or Deepspeed, not both.)
-
-
-Some notes on FSDP:
-
-* Our current FSDP wrapping strategy does not permit training language model embeddings that use tied weights (i.e., tied input / output embeddings). To train such models with FSDP, the language model embeddings must be frozen with the `--freeze_lm_embeddings` flag.
-
 We also implement gradient checkpointing and mixed precision training. Use the `--gradient_checkpointing` and `--precision` arguments respectively.
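Context for the deleted FSDP note, as an assumption on my part rather than anything stated in the commit: PyTorch 2.x FSDP can be constructed with `use_orig_params=True`, and tied parameters that land in the same FSDP unit stay shared, which would make the untying workaround unnecessary. A toy sketch, meant to be run under `torchrun` so a process group exists; the model is made up for illustration.

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TinyTiedLM(nn.Module):
    """Made-up model with tied input/output embeddings."""
    def __init__(self, vocab_size=128, dim=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size, bias=False)
        self.head.weight = self.embed.weight  # weight tying

    def forward(self, ids):
        return self.head(self.embed(ids))

if __name__ == "__main__":
    dist.init_process_group("nccl")  # assumes torchrun set rank / world size
    model = FSDP(TinyTiedLM().cuda(), use_orig_params=True)  # tie preserved
```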
1 change: 0 additions & 1 deletion open_flamingo/train/train.py
Expand Up @@ -291,7 +291,6 @@ def main():
args.lm_path,
args.tokenizer_path if args.tokenizer_path else args.lm_path,
model_family=args.model_family,
untie_embeddings=False, # untie embeddings for FSDP
use_local_files=args.offline,
gradient_checkpointing=args.gradient_checkpointing,
verbose=(args.rank == 0),
Expand Down
