From 8289f28cd0985bae08e3c4f272b0429ed425f989 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 23 Aug 2023 19:03:05 +0900 Subject: [PATCH] More fixes following transformers 4.32 release (#1311) * more fixes * nit * remove duplicate test * nit bis --- optimum/exporters/onnx/model_configs.py | 1 + optimum/onnxruntime/modeling_seq2seq.py | 32 +++++++++++++++++++ tests/exporters/exporters_utils.py | 2 +- .../exporters/onnx/test_exporters_onnx_cli.py | 6 ---- tests/exporters/onnx/test_onnx_export.py | 3 +- tests/onnxruntime/test_modeling.py | 8 ++--- tests/onnxruntime/test_optimization.py | 2 +- 7 files changed, 41 insertions(+), 13 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 78a52def53..4d41ea1e73 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1213,6 +1213,7 @@ class SamOnnxConfig(OnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyPointsGenerator) DEFAULT_ONNX_OPSET = 12 # einsum op not supported with opset 11 + MIN_TORCH_VERSION = version.parse("2.0.99") # See: https://github.com/huggingface/optimum/pull/1301 def __init__(self, config: "PretrainedConfig", task: str = "feature-extraction"): super().__init__(config, task) diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index ee09713390..2908a1c803 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -35,6 +35,7 @@ ) from transformers.file_utils import add_start_docstrings_to_model_forward from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from transformers.models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES import onnxruntime as ort @@ -1083,6 +1084,37 @@ class ORTModelForSpeechSeq2Seq(ORTModelForConditionalGeneration, GenerationMixin auto_model_class = AutoModelForSpeechSeq2Seq main_input_name = "input_features" + def __init__( + self, + encoder_session: ort.InferenceSession, + decoder_session: ort.InferenceSession, + config: "PretrainedConfig", + onnx_paths: List[str], + decoder_with_past_session: Optional[ort.InferenceSession] = None, + use_cache: bool = True, + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + preprocessors: Optional[List] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ): + super().__init__( + encoder_session=encoder_session, + decoder_session=decoder_session, + config=config, + onnx_paths=onnx_paths, + decoder_with_past_session=decoder_with_past_session, + use_cache=use_cache, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + preprocessors=preprocessors, + generation_config=generation_config, + **kwargs, + ) + # Following a breaking change in transformers that relies directly on the mapping name and not on the greedy model mapping (that can be extended), we need to hardcode the ortmodel in this dictionary. Other pipelines do not seem to have controlflow depending on the mapping name. + # See: https://github.com/huggingface/transformers/pull/24960/files + MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES["ort_speechseq2seq"] = self.__class__.__name__ + def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder: return ORTEncoderForSpeech(session, self) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index ab4ce97b75..7a20fa4528 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -95,7 +95,7 @@ "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", - "sam": "fxmarty/sam-vit-tiny-random", + # "sam": "fxmarty/sam-vit-tiny-random", # TODO: re-enable once PyTorch 2.1 is released, see https://github.com/huggingface/optimum/pull/1301 "segformer": "hf-internal-testing/tiny-random-SegformerModel", "splinter": "hf-internal-testing/tiny-random-SplinterModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index a92a5d1881..2d9ef98a26 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -129,12 +129,6 @@ def _onnx_export( except MinimumVersionError as e: pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}") - def test_all_models_tested(self): - # make sure we test all models - missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - set(PYTORCH_EXPORT_MODELS_TINY.keys()) - if len(missing_models_set) > 0: - self.fail(f"Not testing all models. Missing models: {missing_models_set}") - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) @require_torch @require_vision diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 9a96d13e47..7e172452cd 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -338,7 +338,8 @@ def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"): def test_all_models_tested(self): # make sure we test all models missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - set(PYTORCH_EXPORT_MODELS_TINY.keys()) - if len(missing_models_set) > 0: + assert "sam" in missing_models_set # See exporters_utils.py + if len(missing_models_set) > 1: self.fail(f"Not testing all models. Missing models: {missing_models_set}") @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index ab39351319..5a8cd52a75 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1387,9 +1387,9 @@ class ORTModelForSequenceClassificationIntegrationTest(ORTModelTestMixin): "distilbert", "electra", "flaubert", - "gpt2", - "gpt_neo", - "gptj", + # "gpt2", # see tasks.py + # "gpt_neo", # see tasks.py + # "gptj", # see tasks.py "ibert", # TODO: these two should be supported, but require image inputs not supported in ORTModel # "layoutlm" @@ -1418,7 +1418,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: _ = ORTModelForSequenceClassification.from_pretrained(MODEL_NAMES["t5"], export=True) - self.assertIn("Unrecognized configuration class", str(context.exception)) + self.assertIn("that is a custom or unsupported", str(context.exception)) @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index ac7d44f9f8..473d12c68b 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -93,7 +93,7 @@ class ORTOptimizerTest(unittest.TestCase): # (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-big_bird"), (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-distilbert"), (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-electra"), - (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-gpt2"), + (ORTModelForCausalLM, "hf-internal-testing/tiny-random-gpt2"), (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-roberta"), (ORTModelForSequenceClassification, "hf-internal-testing/tiny-xlm-roberta"), )