Add support for custom cache_dir
siddk committed Aug 21, 2023
1 parent 914076e · commit b30616f
Showing 3 changed files with 15 additions and 11 deletions.
.gitignore: 3 additions & 0 deletions
@@ -140,4 +140,7 @@ wandb
# Pyre type checker
.pyre/

+# Cache
+cache/
+
__*.sh
README.md: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ model, image_processor, tokenizer = create_model_and_transforms(
clip_vision_encoder_pretrained="openai",
lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
-cross_attn_every_n_layers=1
+cross_attn_every_n_layers=1,
+cache_dir="PATH/TO/CACHE/DIR" # Defaults to ~/.cache
)
```
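The same custom cache directory can also be reused when downloading the OpenFlamingo checkpoint itself. A minimal sketch, assuming the weights are fetched with `hf_hub_download` as elsewhere in the README (the repo id, filename, and path below are illustrative, not part of this commit):

```python
from huggingface_hub import hf_hub_download
import torch

# Fetch the checkpoint into the same custom cache directory.
# Repo id, filename, and cache path are placeholders; `model` is the
# instance returned by create_model_and_transforms above.
checkpoint_path = hf_hub_download(
    "openflamingo/OpenFlamingo-3B-vitl-mpt1b",
    "checkpoint.pt",
    cache_dir="PATH/TO/CACHE/DIR",
)
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```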

open_flamingo/src/factory.py: 10 additions & 10 deletions
@@ -1,3 +1,5 @@
+from typing import Optional
+
from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip

@@ -15,6 +17,7 @@ def create_model_and_transforms(
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
freeze_lm_embeddings: bool = False,
+cache_dir: Optional[str] = None,
**flamingo_kwargs,
):
"""
@@ -29,26 +32,24 @@ def create_model_and_transforms(
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+freeze_lm_embeddings (bool, optional): whether to freeze LM input embeddings when configuring Perceiver.
+cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
-clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
+clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained, cache_dir=cache_dir
)
# set the vision encoder to output the visual features
vision_encoder.visual.output_tokens = True

text_tokenizer = AutoTokenizer.from_pretrained(
-tokenizer_path,
-local_files_only=use_local_files,
-trust_remote_code=True,
+tokenizer_path, local_files_only=use_local_files, trust_remote_code=True, cache_dir=cache_dir
)
# add Flamingo special tokens to the tokenizer
-text_tokenizer.add_special_tokens(
-{"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
-)
+text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
if text_tokenizer.pad_token is None:
# Issue: GPT models don't have a pad token, which we use to
# modify labels for the loss.
@@ -58,6 +59,7 @@ def create_model_and_transforms(
lang_encoder_path,
local_files_only=use_local_files,
trust_remote_code=True,
+cache_dir=cache_dir,
)

# hacks for MPT-1B, which doesn't have a get_input_embeddings method
@@ -85,9 +87,7 @@ def set_input_embeddings(self, new_embeddings):
lang_encoder,
text_tokenizer.encode("<|endofchunk|>")[-1],
text_tokenizer.encode("<image>")[-1],
-vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-"width"
-],
+vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
cross_attn_every_n_layers=cross_attn_every_n_layers,
**flamingo_kwargs,
)
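Taken together, the factory now threads `cache_dir` through all three download paths: the OpenCLIP vision encoder, the tokenizer, and the language model. Leaving it as `None` keeps each library's default location under `~/.cache`. A minimal usage sketch mirroring the README example, with a hypothetical cache path:

```python
from open_flamingo import create_model_and_transforms

# Route all OpenCLIP / Hugging Face downloads through one directory.
# "/data/flamingo-cache" is a made-up path; omit cache_dir to keep the defaults.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
    cache_dir="/data/flamingo-cache",
)
```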
