Skip to content

Commit

Permalink
Fixes generation during the evaluation step (#266)
Browse files · Browse the repository at this point in the history
* Fix generation during evaluation step

* Fix generation during evaluation step
  • Loading branch information
michaelbenayoun authored Oct 23, 2023
1 parent 56b5cff commit cfc098d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions optimum/neuron/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,11 @@ def generate(
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
if generation_config.use_cache:
warnings.warn(
"use_cache is not supported for generation on Neuron devices, switching to use_cache=False."
)
model_kwargs["use_cache"] = False

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
Expand Down Expand Up @@ -1066,7 +1070,7 @@ def greedy_search(
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
synced_gpus: bool = False,
seq_length: Optional[int] = int,
seq_length: Optional[int] = None,
streamer: Optional["BaseStreamer"] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, torch.LongTensor]:
Expand Down Expand Up @@ -1269,8 +1273,6 @@ def greedy_search(
else:
next_token_logits = outputs.logits[:, -1, :]

xm.mark_step()

# pre-process distribution
# Move to cpu to handle arbitrary logits_processor
next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
Expand Down

0 comments on commit cfc098d

Please sign in to comment.