Skip to content

Commit

Permalink
fix(generation): adapt to new GenerationMixin
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed May 30, 2024
1 parent f5a34e2 commit 35de60a
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions optimum/neuron/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,9 @@ def generate(
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

# 3. Define model inputs
# inputs_tensor has to be defined
Expand Down Expand Up @@ -700,6 +692,9 @@ def generate(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if (
Expand All @@ -725,7 +720,6 @@ def generate(
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
device=inputs_tensor.device,
)
else:
Expand Down

0 comments on commit 35de60a

Please sign in to comment.