use minimal prepare_inputs_for_generation in ORTModelForSpeechSeq2Seq
IlyasMoutawwakil committed Aug 2, 2024
1 parent 3fe0cac commit b3948b9
Showing 1 changed file with 15 additions and 36 deletions.
51 changes: 15 additions & 36 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -33,7 +33,8 @@
     AutoModelForSpeechSeq2Seq,
     AutoModelForVision2Seq,
     GenerationConfig,
-    Pix2StructForConditionalGeneration,  # Pix2struct does not support AutoModel
+    Pix2StructForConditionalGeneration,
+    WhisperForConditionalGeneration,
 )
 from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
@@ -1281,7 +1282,6 @@ def get_encoder(self) -> ORTEncoder:
     # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache
     @staticmethod
     def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
-        print("REORDER CACHE CALLED")
         reordered_past = ()
         for layer_past in past:
             # Cached cross_attention states don't have to be reordered -> they are always the same
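
Note: _reorder_cache is called during beam search to realign the cached key/value states with the re-ranked beams, which is why the stray print removed here fired on every reorder step. For reference, a self-contained sketch of the copied BART helper, assuming past is a tuple of per-layer tuples whose first two tensors are the self-attention key/value caches:

import torch
from typing import Tuple

def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
    reordered_past = ()
    for layer_past in past:
        # Self-attention caches (the first two tensors) follow the beams;
        # cached cross-attention states are identical across beams and are kept as-is.
        reordered_past += (
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
        )
    return reordered_past
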
@@ -1381,51 +1381,27 @@ def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
         past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        cache_position=None,
         **kwargs,
     ):
-        decoder_position_ids = None
-        if decoder_attention_mask is not None:
-            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
-
-        past_length = 0
+        # cut decoder_input_ids if past is used
         if past_key_values is not None:
-            if isinstance(past_key_values, EncoderDecoderCache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            else:
-                past_length = past_key_values[0][0].shape[2]
-
-            # Some generation methods already pass only the last input ID
-            if decoder_input_ids.shape[1] > past_length:
-                remove_prefix_length = past_length
-            else:
-                # Default to old behavior: keep only final ID
-                remove_prefix_length = decoder_input_ids.shape[1] - 1
-
-            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
-            if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
-                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
-
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
-            )
-        elif use_cache:
-            cache_position = cache_position[-decoder_input_ids.shape[1] :]
+            decoder_input_ids = decoder_input_ids[:, -1:]

         return {
             "encoder_outputs": encoder_outputs,
             "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "decoder_position_ids": decoder_position_ids,
-            "cache_position": cache_position,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,
         }

     def get_encoder(self) -> ORTEncoder:
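
Note: the minimal version drops the past_length, decoder_position_ids, and cache_position bookkeeping entirely; whenever a cache is present, only the last decoder token is fed on each step. A sketch of the resulting trimming behavior, with made-up token IDs and a stand-in cache object:

import torch

decoder_input_ids = torch.tensor([[50258, 50359, 50363]])  # three tokens generated so far (made-up IDs)
past_key_values = object()  # stand-in for a real ONNX Runtime key/value cache

# mirrors the minimal logic above
if past_key_values is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]

print(decoder_input_ids)  # tensor([[50363]]): only the final token reaches the decoder
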
@@ -1460,6 +1436,9 @@ class _ORTModelForWhisper(WhisperGenerationMixin, ORTModelForSpeechSeq2Seq):
     Whisper implements its own generate() method.
     """

+    auto_model_class = WhisperForConditionalGeneration
+    prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation
+
     @classmethod
     def _from_pretrained(
         cls,
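
Note: assigning WhisperForConditionalGeneration.prepare_inputs_for_generation as a class attribute lets the ORT subclass reuse the upstream transformers method unmodified, so Whisper-specific input handling stays in sync with the source model class. A minimal sketch of this Python delegation pattern, with hypothetical class names:

class Upstream:
    def prepare(self, x):
        return {"x": x, "cls": type(self).__name__}

class OrtVariant:
    # A plain function assigned in a class body becomes an ordinary method,
    # so OrtVariant().prepare(...) binds self to the OrtVariant instance.
    prepare = Upstream.prepare

print(OrtVariant().prepare(1))  # {'x': 1, 'cls': 'OrtVariant'}
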
