diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py
index 01db18b7..1940065f 100644
--- a/open_flamingo/src/flamingo.py
+++ b/open_flamingo/src/flamingo.py
@@ -126,19 +126,7 @@ def generate(
         vision_x: torch.Tensor,
         lang_x: torch.Tensor,
         attention_mask: torch.Tensor = None,
-        num_beams=1,
-        min_new_tokens=None,
-        max_new_tokens=None,
-        temperature=1.0,
-        top_k=0,
-        top_p=1.0,
-        no_repeat_ngram_size=0,
-        repetition_penalty=1.0,
-        prefix_allowed_tokens_fn=None,
-        length_penalty=1.0,
-        num_return_sequences=1,
-        do_sample=False,
-        early_stopping=False,
+        **kwargs,
     ):
         """
         Generate text conditioned on vision and language inputs.
@@ -150,44 +138,36 @@ def generate(
                 currently only F=1 is supported (single-frame videos)
             lang_x (torch.Tensor): Language input
                 shape (B, T_txt)
-            max_length (int, optional): Maximum length of the output. Defaults to None.
-            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
-            num_beams (int, optional): Number of beams. Defaults to 1.
-            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
-            temperature (float, optional): Temperature. Defaults to 1.0.
-            top_k (int, optional): Top k. Defaults to 0.
-            top_p (float, optional): Top p. Defaults to 1.0.
-            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
-            length_penalty (float, optional): Length penalty. Defaults to 1.0.
-            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
-            do_sample (bool, optional): Do sample. Defaults to False.
-            early_stopping (bool, optional): Early stopping. Defaults to False.
+            **kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
+                max_length (int, optional): Maximum length of the output. Defaults to None.
+                attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+                num_beams (int, optional): Number of beams. Defaults to 1.
+                max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+                temperature (float, optional): Temperature. Defaults to 1.0.
+                top_k (int, optional): Top k. Defaults to 50.
+                top_p (float, optional): Top p. Defaults to 1.0.
+                no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+                length_penalty (float, optional): Length penalty. Defaults to 1.0.
+                num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+                do_sample (bool, optional): Do sample. Defaults to False.
+                early_stopping (bool, optional): Early stopping. Defaults to False.
         Returns:
             torch.Tensor: lang_x with generated tokens appended to it
         """
+        num_beams = kwargs.pop('num_beams', 1)
         if num_beams > 1:
             vision_x = vision_x.repeat_interleave(num_beams, dim=0)
 
         self.lang_encoder._use_cached_vision_x = True
         self._encode_vision_x(vision_x=vision_x)
 
+        eos_token_id = kwargs.pop('eos_token_id', self.eoc_token_id)
         output = self.lang_encoder.generate(
             input_ids=lang_x,
             attention_mask=attention_mask,
-            eos_token_id=self.eoc_token_id,
+            eos_token_id=eos_token_id,
             num_beams=num_beams,
-            min_new_tokens=min_new_tokens,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            num_return_sequences=num_return_sequences,
-            do_sample=do_sample,
-            early_stopping=early_stopping,
+            **kwargs,
         )
 
         self.lang_encoder.clear_conditioned_layers()
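
After this change, Flamingo.generate forwards any Hugging Face generate() option through **kwargs, popping only num_beams (needed to repeat the vision features once per beam) and eos_token_id (so the end-of-chunk token remains the default stop token). A minimal usage sketch, assuming `model` (a Flamingo instance) and its `tokenizer` already exist; the dummy tensors and prompt are illustrative, not part of this diff:

    import torch

    # Assumed to exist already: `model` and `tokenizer`; how they are
    # constructed is outside the scope of this diff.
    vision_x = torch.randn(1, 1, 1, 3, 224, 224)  # (B, T_img, F, C, H, W), F=1
    lang_x = tokenizer(["<image>An image of"], return_tensors="pt")

    # Any Hugging Face CausalLM generation kwarg can now be passed through.
    output = model.generate(
        vision_x=vision_x,
        lang_x=lang_x["input_ids"],
        attention_mask=lang_x["attention_mask"],
        max_new_tokens=20,
        num_beams=3,  # popped internally to repeat vision features per beam
        length_penalty=1.0,
    )
    print(tokenizer.decode(output[0]))

Because unrecognized kwargs are handed straight to self.lang_encoder.generate, generation options added in future Hugging Face releases work without touching this wrapper; the trade-off is that a misspelled kwarg only surfaces as an error from the underlying model.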