From c629059bd1b190a6a90ecda1e542dd5ded7885da Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Thu, 17 Aug 2023 14:23:48 +0200
Subject: [PATCH] added io tests

---
 optimum/onnxruntime/modeling_seq2seq.py      |  3 +-
 tests/onnxruntime/test_modeling.py           | 87 ++++++++++++++++++++
 tests/onnxruntime/utils_onnxruntime_tests.py |  2 +-
 3 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 8909f7b430..5f637b8989 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -403,7 +403,8 @@ def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor)
             )
         else:
             raise ValueError(
-                f"Unsupported encoder model type {self.normalized_config.config.model_type} for VisionEncoderDecoder."
+                f"Unsupported encoder model type {self.normalized_config.config.model_type} for ORTModelForVision2Seq with IO binding. "
+                "Currently supported models are vit and donut-swin. "
                 "Please submit a PR to add support for this model type."
             )
 
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 38eec14507..e421367f8d 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -4187,6 +4187,93 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):
             f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
         )
 
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+    )
+    @require_torch_gpu
+    @pytest.mark.gpu_test
+    def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
+        if use_cache is False and use_merged is True:
+            self.skipTest("use_cache=False, use_merged=True are incompatible")
+
+        model_args = {
+            "test_name": test_name,
+            "model_arch": model_arch,
+            "use_cache": use_cache,
+            "use_merged": use_merged,
+        }
+        self._setup(model_args)
+
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to(
+            "cuda"
+        )
+        io_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to(
+            "cuda"
+        )
+
+        self.assertFalse(onnx_model.use_io_binding)
+        self.assertTrue(io_model.use_io_binding)
+
+        feature_extractor, tokenizer = self._get_preprocessors(model_id)
+
+        data = self._get_sample_image()
+        features = feature_extractor([data] * 2, return_tensors="pt").to("cuda")
+
+        decoder_start_token_id = onnx_model.config.decoder.bos_token_id
+        decoder_inputs = {"decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id}
+
+        onnx_outputs = onnx_model(**features, **decoder_inputs)
+        io_outputs = io_model(**features, **decoder_inputs)
+
+        self.assertTrue("logits" in io_outputs)
+        self.assertIsInstance(io_outputs.logits, torch.Tensor)
+
+        # compare tensor outputs
+        self.assertTrue(torch.equal(onnx_outputs.logits, io_outputs.logits))
+
+        gc.collect()
+
+    @parameterized.expand(
+        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+    )
+    @require_torch_gpu
+    @pytest.mark.gpu_test
+    def test_compare_generation_to_io_binding(
+        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+    ):
+        if use_cache is False and use_merged is True:
+            self.skipTest("use_cache=False, use_merged=True are incompatible")
+
+        model_args = {
+            "test_name": test_name,
+            "model_arch": model_arch,
+            "use_cache": use_cache,
+            "use_merged": use_merged,
+        }
+        self._setup(model_args)
+
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to(
+            "cuda"
+        )
+        io_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to(
+            "cuda"
+        )
+
+        feature_extractor, tokenizer = self._get_preprocessors(model_id)
+
+        data = self._get_sample_image()
+        features = feature_extractor(data, return_tensors="pt").to("cuda")
+
+        onnx_outputs = onnx_model.generate(**features, num_beams=5)
+        io_outputs = io_model.generate(**features, num_beams=5)
+
+        # compare tensor outputs
+        self.assertTrue(torch.equal(onnx_outputs, io_outputs))
+
+        gc.collect()
 
 
 class ORTModelForCustomTasksIntegrationTest(ORTModelTestMixin):
     SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index 4442583d8e..1399c14091 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -46,6 +46,7 @@
     "deberta": "hf-internal-testing/tiny-random-DebertaModel",
     "deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
     "deit": "hf-internal-testing/tiny-random-DeiTModel",
+    "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "detr": "hf-internal-testing/tiny-random-detr",
     "distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
@@ -99,7 +100,6 @@
     "xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
     "vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
     "trocr": "microsoft/trocr-small-handwritten",
-    "donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
 }
 
 SEED = 42
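
Note: the snippet below is an illustrative sketch of the usage pattern these tests exercise, not part of the patch itself. It assumes onnxruntime-gpu with a CUDA device; the checkpoint is the tiny vision-encoder-decoder test model listed in utils_onnxruntime_tests.py, and the image path is a placeholder.

# Illustrative only -- not part of the patch. Assumes onnxruntime-gpu and a CUDA device;
# "sample.jpg" is any local image. Mirrors what test_compare_generation_to_io_binding checks:
# the same ONNX export should generate identical token ids with and without IO binding.
import torch
from PIL import Image
from transformers import AutoFeatureExtractor
from optimum.onnxruntime import ORTModelForVision2Seq

model_id = "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# Export the checkpoint to ONNX and load it twice: once running plain InferenceSession calls
# (host<->device copies on every step) and once with IO binding (tensors stay on the GPU).
onnx_model = ORTModelForVision2Seq.from_pretrained(model_id, export=True, use_io_binding=False).to("cuda")
io_model = ORTModelForVision2Seq.from_pretrained(model_id, export=True, use_io_binding=True).to("cuda")

image = Image.open("sample.jpg").convert("RGB")  # placeholder image path
features = feature_extractor(image, return_tensors="pt").to("cuda")

onnx_ids = onnx_model.generate(**features, num_beams=5)
io_ids = io_model.generate(**features, num_beams=5)
assert torch.equal(onnx_ids, io_ids)  # IO binding must not change the generated tokens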