Skip to content

Commit

Permalink
More fixes following transformers 4.32 release (#1311)
Browse files Browse the repository at this point in the history
* more fixes

* nit

* remove duplicate test

* nit bis
  • Loading branch information
fxmarty authored Aug 23, 2023
1 parent 7e932ec commit 8289f28
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 13 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,7 @@ class SamOnnxConfig(OnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator)
DEFAULT_ONNX_OPSET = 12 # einsum op not supported with opset 11
MIN_TORCH_VERSION = version.parse("2.0.99") # See: https://github.com/huggingface/optimum/pull/1301

def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"):
super().__init__(config, task)
Expand Down
32 changes: 32 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from transformers.file_utils import add_start_docstrings_to_model_forward
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers.models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES

import onnxruntime as ort

Expand Down Expand Up @@ -1083,6 +1084,37 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin
auto_model_class = AutoModelForSpeechSeq2Seq
main_input_name = "input_features"

def __init__(
    self,
    encoder_session: ort.InferenceSession,
    decoder_session: ort.InferenceSession,
    config: "PretrainedConfig",
    onnx_paths: List[str],
    decoder_with_past_session: Optional[ort.InferenceSession] = None,
    use_cache: bool = True,
    use_io_binding: Optional[bool] = None,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    preprocessors: Optional[List] = None,
    generation_config: Optional[GenerationConfig] = None,
    **kwargs,
):
    """Initialize the model through the parent constructor and register its class name.

    Every argument is forwarded unchanged, by keyword, to the parent
    `__init__`. Afterwards the name of the instantiated class is stored in
    `MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES` under the key
    `"ort_speechseq2seq"`.
    """
    forwarded = {
        "encoder_session": encoder_session,
        "decoder_session": decoder_session,
        "config": config,
        "onnx_paths": onnx_paths,
        "decoder_with_past_session": decoder_with_past_session,
        "use_cache": use_cache,
        "use_io_binding": use_io_binding,
        "model_save_dir": model_save_dir,
        "preprocessors": preprocessors,
        "generation_config": generation_config,
    }
    super().__init__(**forwarded, **kwargs)

    # Following a breaking change in transformers that relies directly on the
    # mapping name and not on the greedy model mapping (that can be extended),
    # we need to hardcode the ORT model in this dictionary. Other pipelines do
    # not seem to have control flow depending on the mapping name.
    # See: https://github.com/huggingface/transformers/pull/24960/files
    registered_name = self.__class__.__name__
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES["ort_speechseq2seq"] = registered_name

def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder:
    """Build the encoder wrapper for this model.

    Returns an `ORTEncoderForSpeech` constructed from the given inference
    session and this model instance.
    """
    speech_encoder = ORTEncoderForSpeech(session, self)
    return speech_encoder

Expand Down
2 changes: 1 addition & 1 deletion tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
"resnet": "hf-internal-testing/tiny-random-resnet",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"sam": "fxmarty/sam-vit-tiny-random",
# "sam": "fxmarty/sam-vit-tiny-random", # TODO: re-enable once PyTorch 2.1 is released, see https://github.com/huggingface/optimum/pull/1301
"segformer": "hf-internal-testing/tiny-random-SegformerModel",
"splinter": "hf-internal-testing/tiny-random-SplinterModel",
"squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel",
Expand Down
6 changes: 0 additions & 6 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ def _onnx_export(
except MinimumVersionError as e:
pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

def test_all_models_tested(self):
    """Fail if any supported CLI model type has no entry in PYTORCH_EXPORT_MODELS_TINY."""
    tested_model_types = set(PYTORCH_EXPORT_MODELS_TINY.keys())
    missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - tested_model_types
    if not missing_models_set:
        return
    self.fail(f"Not testing all models. Missing models: {missing_models_set}")

@parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items())
@require_torch
@require_vision
Expand Down
3 changes: 2 additions & 1 deletion tests/exporters/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"):
def test_all_models_tested(self):
    """Check that the tiny-model table covers every supported CLI model type except "sam".

    The test asserts that "sam" is among the missing model types and fails
    when more than one model type is missing.
    """
    tested_model_types = set(PYTORCH_EXPORT_MODELS_TINY.keys())
    missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - tested_model_types
    # "sam" is expected to be missing (see exporters_utils.py), so a single
    # missing entry is tolerated.
    assert "sam" in missing_models_set
    if len(missing_models_set) <= 1:
        return
    self.fail(f"Not testing all models. Missing models: {missing_models_set}")

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
Expand Down
8 changes: 4 additions & 4 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,9 +1387,9 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin):
"distilbert",
"electra",
"flaubert",
"gpt2",
"gpt_neo",
"gptj",
# "gpt2", # see tasks.py
# "gpt_neo", # see tasks.py
# "gptj", # see tasks.py
"ibert",
# TODO: these two should be supported, but require image inputs not supported in ORTModel
# "layoutlm"
Expand Down Expand Up @@ -1418,7 +1418,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["t5"], export=True)

self.assertIn("Unrecognized configuration class", str(context.exception))
self.assertIn("that is a custom or unsupported", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
Expand Down
2 changes: 1 addition & 1 deletion tests/onnxruntime/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class ORTOptimizerTest(unittest.TestCase):
# (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-big_bird"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-distilbert"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-electra"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-gpt2"),
(ORTModelForCausalLM, "hf-internal-testing/tiny-random-gpt2"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-roberta"),
(ORTModelForSequenceClassification, "hf-internal-testing/tiny-xlm-roberta"),
)
Expand Down

0 comments on commit 8289f28

Please sign in to comment.