From c7cc31261044890eaafe9071499c37d5bbad1bbd Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Tue, 27 Feb 2024 15:27:26 +0100
Subject: [PATCH] Fix starcoder ORT integration (#1722)

* fix starcoder ort

* fix pix2struct as well
---
 optimum/onnxruntime/base.py             |  5 +-
 optimum/onnxruntime/modeling_decoder.py |  7 ++
 tests/onnxruntime/test_modeling.py      | 85 +++++++++++++++++++------
 3 files changed, 76 insertions(+), 21 deletions(-)

diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index e36ca798d5..bf9c80a86c 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -260,7 +260,10 @@ def forward(
 
         outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)
 
-        model_inputs = [input_ids]
+        # TODO: fix transformers generate to have contiguous input_ids here already
+        # For an unknown reason, calling `contiguous()` here is necessary to not have errors
+        # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
+        model_inputs = [input_ids.contiguous()]
 
         if "encoder_hidden_states" in self.input_names:
             model_inputs.append(encoder_hidden_states)
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 0f3a525e9a..661e7faa25 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -722,6 +722,13 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
         )
         return model_inputs
 
+    # Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
+    @staticmethod
+    def _reorder_cache(
+        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+    ) -> Tuple[Tuple[torch.Tensor]]:
+        return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
+
 
 class ORTBloomForCausalLM(ORTModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 0e65dd8b21..b9f7cdcd4d 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -2322,8 +2322,8 @@ def test_merge_from_onnx_and_save(self, model_arch):
         self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
         self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)
 
-    @parameterized.expand(grid_parameters(FULL_GRID))
-    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
+    @parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]}))
+    def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
         use_io_binding = None
         if use_cache is False:
             use_io_binding = False
@@ -2384,17 +2384,19 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         if model_arch == "falcon":
             # TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers
             new_tokens = 5
+
         onnx_outputs = onnx_model.generate(
             **tokens,
-            num_beams=1,
+            num_beams=num_beams,
             do_sample=False,
             min_new_tokens=new_tokens,
             max_new_tokens=new_tokens,
             eos_token_id=None,
         )
+
         transformers_outputs = transformers_model.generate(
             **tokens,
-            num_beams=1,
+            num_beams=num_beams,
             do_sample=False,
             min_new_tokens=new_tokens,
             max_new_tokens=new_tokens,
@@ -4123,11 +4125,23 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     @require_torch_gpu
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4159,8 +4173,8 @@ def test_compare_generation_to_io_binding(
         tokenizer = get_preprocessor(model_id)
         tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
 
-        onnx_outputs = onnx_model.generate(**tokens, num_beams=5)
-        io_outputs = io_model.generate(**tokens, num_beams=5)
+        onnx_outputs = onnx_model.generate(**tokens, num_beams=num_beams)
+        io_outputs = io_model.generate(**tokens, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -4555,12 +4569,24 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 5],
+            }
+        )
     )
     @require_torch_gpu
     @pytest.mark.cuda_ep_test
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4586,8 +4612,8 @@ def test_compare_generation_to_io_binding(
         data = self._generate_random_audio_data()
         features = processor.feature_extractor(data, return_tensors="pt").to("cuda")
 
-        onnx_outputs = onnx_model.generate(**features, num_beams=5)
-        io_outputs = io_model.generate(**features, num_beams=5)
+        onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
+        io_outputs = io_model.generate(**features, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -4920,12 +4946,19 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     @require_torch_gpu
     @pytest.mark.cuda_ep_test
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, num_beams: int
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4951,8 +4984,8 @@ def test_compare_generation_to_io_binding(
         data = self._get_sample_image()
         features = feature_extractor(data, return_tensors="pt").to("cuda")
 
-        onnx_outputs = onnx_model.generate(**features, num_beams=5)
-        io_outputs = io_model.generate(**features, num_beams=5)
+        onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
+        io_outputs = io_model.generate(**features, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -5336,10 +5369,22 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
         gc.collect()
 
     @parameterized.expand(
-        grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
+        grid_parameters(
+            {
+                "model_arch": SUPPORTED_ARCHITECTURES,
+                "use_cache": [True],
+                "use_merged": [False, True],
+                "num_beams": [1, 3],
+            }
+        )
     )
     def test_compare_generation_to_io_binding(
-        self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
+        self,
+        test_name: str,
+        model_arch: str,
+        use_cache: bool,
+        use_merged: bool,
+        num_beams: int,
     ):
         if use_cache is False and use_merged is True:
             self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -5362,8 +5407,8 @@ def test_compare_generation_to_io_binding(
         inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt")
         del inputs["decoder_attention_mask"]
         del inputs["decoder_input_ids"]
-        onnx_outputs = onnx_model.generate(**inputs, num_beams=5)
-        io_outputs = io_model.generate(**inputs, num_beams=5)
+        onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams)
+        io_outputs = io_model.generate(**inputs, num_beams=num_beams)
 
         # compare tensor outputs
         self.assertTrue(torch.equal(onnx_outputs, io_outputs))
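
For context, a minimal sketch of the user-facing path this patch fixes: beam search (num_beams > 1) with a GPT-BigCode / starcoder-style checkpoint loaded through ORTModelForCausalLM, which now goes through the _reorder_cache override added above. The checkpoint name, prompts, and generation settings below are illustrative assumptions and are not part of the patch.

# Sketch only: exercises beam search on a GPT-BigCode model through ONNX Runtime.
# "bigcode/tiny_starcoder_py" is an assumed example checkpoint, not from the patch;
# any GPT-BigCode checkpoint should hit the same code path.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "bigcode/tiny_starcoder_py"

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # GPT-style tokenizers often lack a pad token; reuse EOS for batched padding
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # decoder-only models should be left-padded for generation

# export=True converts the PyTorch checkpoint to ONNX on the fly
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer(["def fibonacci(n):", "def hello():"], return_tensors="pt", padding=True)

# num_beams > 1 triggers _reorder_cache on the ORT model; before this patch the
# generic tuple-of-tuples reordering did not match GPT-BigCode's single tensor
# per layer cache layout.
outputs = model.generate(**inputs, num_beams=3, do_sample=False, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))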