Change generate function to pass args directly to huggingface.
DavidMChan committed Aug 2, 2023
1 parent bcf220c commit 68e9a91
Showing 1 changed file with 18 additions and 38 deletions.
56 changes: 18 additions & 38 deletions open_flamingo/src/flamingo.py
@@ -126,19 +126,7 @@ def generate(
         vision_x: torch.Tensor,
         lang_x: torch.Tensor,
         attention_mask: torch.Tensor = None,
-        num_beams=1,
-        min_new_tokens=None,
-        max_new_tokens=None,
-        temperature=1.0,
-        top_k=0,
-        top_p=1.0,
-        no_repeat_ngram_size=0,
-        repetition_penalty=1.0,
-        prefix_allowed_tokens_fn=None,
-        length_penalty=1.0,
-        num_return_sequences=1,
-        do_sample=False,
-        early_stopping=False,
+        **kwargs,
     ):
         """
         Generate text conditioned on vision and language inputs.
@@ -150,44 +138,36 @@ def generate(
                 currently only F=1 is supported (single-frame videos)
             lang_x (torch.Tensor): Language input
                 shape (B, T_txt)
-            max_length (int, optional): Maximum length of the output. Defaults to None.
-            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-            num_beams (int, optional): Number of beams. Defaults to 1.
-            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
-            temperature (float, optional): Temperature. Defaults to 1.0.
-            top_k (int, optional): Top k. Defaults to 0.
-            top_p (float, optional): Top p. Defaults to 1.0.
-            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
-            length_penalty (float, optional): Length penalty. Defaults to 1.0.
-            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
-            do_sample (bool, optional): Do sample. Defaults to False.
-            early_stopping (bool, optional): Early stopping. Defaults to False.
+            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
+                max_length (int, optional): Maximum length of the output. Defaults to None.
+                attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+                num_beams (int, optional): Number of beams. Defaults to 1.
+                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+                temperature (float, optional): Temperature. Defaults to 1.0.
+                top_k (int, optional): Top k. Defaults to 50.
+                top_p (float, optional): Top p. Defaults to 1.0.
+                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+                length_penalty (float, optional): Length penalty. Defaults to 1.0.
+                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+                do_sample (bool, optional): Do sample. Defaults to False.
+                early_stopping (bool, optional): Early stopping. Defaults to False.
         Returns:
             torch.Tensor: lang_x with generated tokens appended to it
         """
+        num_beams = kwargs.pop('num_beams', 1)
         if num_beams > 1:
             vision_x = vision_x.repeat_interleave(num_beams, dim=0)

         self.lang_encoder._use_cached_vision_x = True
         self._encode_vision_x(vision_x=vision_x)

+        eos_token_id = kwargs.pop('eos_token_id', self.eoc_token_id)
         output = self.lang_encoder.generate(
             input_ids=lang_x,
             attention_mask=attention_mask,
-            eos_token_id=self.eoc_token_id,
+            eos_token_id=eos_token_id,
             num_beams=num_beams,
-            min_new_tokens=min_new_tokens,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            num_return_sequences=num_return_sequences,
-            do_sample=do_sample,
-            early_stopping=early_stopping,
+            **kwargs,
         )

         self.lang_encoder.clear_conditioned_layers()
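After this change, Flamingo.generate only intercepts num_beams (to repeat the vision features once per beam) and eos_token_id (defaulting to the end-of-chunk token); every other decoding option is forwarded verbatim to the underlying Hugging Face generate call. Below is a minimal caller-side sketch, assuming an already-constructed open_flamingo model and tokenizer; the tensor shapes, prompt, and decoding values are illustrative and not part of this commit.

    import torch

    # Hypothetical setup: `model` is a Flamingo instance and `tokenizer` its text tokenizer.
    # A single image gives vision_x of shape (B, T_img, F, C, H, W) with F=1.
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)
    lang_x = tokenizer(["<image>An image of"], return_tensors="pt")

    generated = model.generate(
        vision_x=vision_x,
        lang_x=lang_x["input_ids"],
        attention_mask=lang_x["attention_mask"],
        # Plain Hugging Face generation kwargs, now passed through **kwargs:
        max_new_tokens=20,
        num_beams=3,   # popped by Flamingo.generate to repeat vision features per beam
        length_penalty=1.0,
        do_sample=False,
    )
    print(tokenizer.decode(generated[0]))

Note that options omitted by the caller now fall back to the Hugging Face defaults rather than the old keyword defaults (e.g. top_k becomes 50 instead of 0), as reflected in the updated docstring.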
