diff --git a/.gitignore b/.gitignore
index fc144471..4480e2fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -140,4 +140,7 @@ wandb
 # Pyre type checker
 .pyre/
 
+# Cache
+cache/
+
 __*.sh
\ No newline at end of file
diff --git a/README.md b/README.md
index 4d3c58d2..c90512c2 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,8 @@ model, image_processor, tokenizer = create_model_and_transforms(
     clip_vision_encoder_pretrained="openai",
     lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
     tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
-    cross_attn_every_n_layers=1
+    cross_attn_every_n_layers=1,
+    cache_dir="PATH/TO/CACHE/DIR"  # Defaults to ~/.cache
 )
 ```
 
diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py
index 158e08b7..4ac9df40 100644
--- a/open_flamingo/src/factory.py
+++ b/open_flamingo/src/factory.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import open_clip
 
@@ -15,6 +17,7 @@ def create_model_and_transforms(
     use_local_files: bool = False,
     decoder_layers_attr_name: str = None,
     freeze_lm_embeddings: bool = False,
+    cache_dir: Optional[str] = None,
     **flamingo_kwargs,
 ):
     """
@@ -29,26 +32,24 @@ def create_model_and_transforms(
         cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
         use_local_files (bool, optional): whether to use local files. Defaults to False.
         decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+        freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
+        cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
     Returns:
         Flamingo: Flamingo model from pretrained vision and language encoders
         Image processor: Pipeline to preprocess input images
         Tokenizer: A tokenizer for the language model
     """
     vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
-        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
+        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained, cache_dir=cache_dir
     )
     # set the vision encoder to output the visual features
     vision_encoder.visual.output_tokens = True
 
     text_tokenizer = AutoTokenizer.from_pretrained(
-        tokenizer_path,
-        local_files_only=use_local_files,
-        trust_remote_code=True,
+        tokenizer_path, local_files_only=use_local_files, trust_remote_code=True, cache_dir=cache_dir
     )
     # add Flamingo special tokens to the tokenizer
-    text_tokenizer.add_special_tokens(
-        {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
-    )
+    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
     if text_tokenizer.pad_token is None:
         # Issue: GPT models don't have a pad token, which we use to
         # modify labels for the loss.
@@ -58,6 +59,7 @@ def create_model_and_transforms(
         lang_encoder_path,
         local_files_only=use_local_files,
         trust_remote_code=True,
+        cache_dir=cache_dir,
     )
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
@@ -85,9 +87,7 @@ def set_input_embeddings(self, new_embeddings):
         lang_encoder,
         text_tokenizer.encode("<|endofchunk|>")[-1],
         text_tokenizer.encode("<image>")[-1],
-        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-            "width"
-        ],
+        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
         **flamingo_kwargs,
     )
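
For reference, a minimal usage sketch of the new `cache_dir` argument, matching the README example above. The `OPENFLAMINGO_CACHE` environment variable and the fallback path are illustrative placeholders, not part of this change; omitting `cache_dir` keeps the previous behavior of downloading to the default `~/.cache` location.

```python
import os

from open_flamingo import create_model_and_transforms

# Hypothetical convention: let an env var override the download cache;
# any writable directory works. Without it, fall back to ~/.cache.
cache_dir = os.environ.get("OPENFLAMINGO_CACHE", os.path.expanduser("~/.cache"))

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    cache_dir=cache_dir,  # forwarded to both the open_clip and Hugging Face downloads
)
```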