diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 53abf43e01..7248be4e37 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -33,7 +33,8 @@
     AutoModelForSpeechSeq2Seq,
     AutoModelForVision2Seq,
     GenerationConfig,
-    Pix2StructForConditionalGeneration,  # Pix2struct does not support AutoModel
+    Pix2StructForConditionalGeneration,
+    WhisperForConditionalGeneration,
 )
 from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -1281,7 +1282,6 @@ def get_encoder(self) -> ORTEncoder:
     # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
     @staticmethod
     def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        print("REORDER CACHE CALLED")
         reordered_past = ()
         for layer_past in past:
             # Cached cross_attention states don't have to be reordered -> they are always the same
@@ -1381,51 +1381,27 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        cache_position=None,
         **kwargs,
     ):
-        decoder_position_ids = None
-        if decoder_attention_mask is not None:
-            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
-
-        past_length = 0
+        # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            if isinstance(past_key_values, EncoderDecoderCache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            else:
-                past_length = past_key_values[0][0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if decoder_input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = decoder_input_ids.shape[1] - 1
-
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
-            if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
-                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
-
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
-            )
-        elif use_cache:
-            cache_position = cache_position[-decoder_input_ids.shape[1] :]
+            decoder_input_ids = decoder_input_ids[:, -1:]
 
         return {
             "encoder_outputs": encoder_outputs,
             "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
-            "decoder_attention_mask": decoder_attention_mask,
-            "decoder_position_ids": decoder_position_ids,
-            "cache_position": cache_position,
         }
 
     def get_encoder(self) -> ORTEncoder:
@@ -1460,6 +1436,9 @@ class _ORTModelForWhisper(WhisperGenerationMixin, ORTModelForSpeechSeq2Seq):
     Whisper implements its own generate() method.
     """
 
+    auto_model_class = WhisperForConditionalGeneration
+    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
+
     @classmethod
     def _from_pretrained(
         cls,