Skip to content

Commit

Permalink
Add can_generate method to fix compatibility with transformers v4.39.0
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Mar 22, 2024
1 parent e6641b0 commit 760c947
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 26 deletions.
4 changes: 0 additions & 4 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,6 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) ->
for layer_past in past
)

def can_generate(self) -> bool:
    """Confirm that this model is usable with `GenerationMixin.generate()`."""
    # Decoder-only ORT models are always generation-capable.
    return True


class ORTGPTBigCodeForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM.prepare_inputs_for_generation
Expand Down
7 changes: 7 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SequenceClassifierOutput,
TokenClassifierOutput,
XVectorOutput,
GenerationMixin,
)

import onnxruntime as ort
Expand Down Expand Up @@ -894,6 +895,12 @@ def _cached_file(

return model_cache_path, preprocessors

def can_generate(self) -> bool:
    """
    Report whether this model can produce sequences with `.generate()`.

    The capability is tied to the class hierarchy: only instances whose
    concrete class mixes in `GenerationMixin` are generation-capable.
    """
    is_generation_capable = isinstance(self, GenerationMixin)
    return is_generation_capable


FEATURE_EXTRACTION_EXAMPLE = r"""
Example of feature extraction:
Expand Down
22 changes: 0 additions & 22 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,12 +1111,6 @@ def to(self, device: Union[torch.device, str, int]):

return self

def can_generate(self) -> bool:
    """Always return False: this abstract base class must not be used for generation directly."""
    # Warn (rather than raise) so callers probing the capability get a clear hint
    # about which concrete subclass to use instead.
    logger.warning(
        "ORTModelForConditionalGeneration is an abstract class and is not meant to be used for generation. Please use ORTModelForSeq2SeqLM or ORTModelForSpeechSeq2Seq."
    )
    return False


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForSeq2SeqLM(ORTModelForConditionalGeneration, GenerationMixin):
Expand Down Expand Up @@ -1262,10 +1256,6 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
)
return reordered_past

def can_generate(self) -> bool:
    """Validate the `GenerationMixin.generate()` capability check for this model."""
    # Seq2seq LM models always support generation.
    return True


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin):
Expand Down Expand Up @@ -1397,10 +1387,6 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
)
return reordered_past

def can_generate(self) -> bool:
    """State that this model passes the `GenerationMixin.generate()` capability check."""
    # Speech seq2seq models always support generation.
    return True

@classmethod
def _from_pretrained(
cls,
Expand Down Expand Up @@ -1986,10 +1972,6 @@ def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
)
return reordered_past

def can_generate(self) -> bool:
    """Affirm that `GenerationMixin.generate()` may be called on this model."""
    # Generation is always available for this model class.
    return True


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForPix2Struct(ORTModelForConditionalGeneration, GenerationMixin):
Expand Down Expand Up @@ -2105,7 +2087,3 @@ def get_encoder(self) -> ORTEncoder:
@staticmethod
def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]:
    """
    Reorder the cached past key/values to follow `beam_idx` during beam search.

    Delegates to `ORTModelForSeq2SeqLM._reorder_cache`, which implements the
    actual per-layer reordering.
    """
    # Bug fix: the delegated result was previously discarded (the method fell
    # through and returned None) despite the annotated tuple return type.
    return ORTModelForSeq2SeqLM._reorder_cache(past, beam_idx)

def can_generate(self) -> bool:
    """Indicate that the `GenerationMixin.generate()` capability check passes."""
    # Pix2Struct models always support generation.
    return True

0 comments on commit 760c947

Please sign in to comment.