Fix starcoder ORT integration (#1722)
* fix starcoder ort

* fix pix2struct as well
fxmarty authored Feb 27, 2024
1 parent 80e89f1 commit c7cc312
Showing 3 changed files with 76 additions and 21 deletions.
5 changes: 4 additions & 1 deletion optimum/onnxruntime/base.py
@@ -260,7 +260,10 @@ def forward(

outputs_to_not_bind = self.get_outputs_not_to_bind(use_merged_cache)

model_inputs = [input_ids]
# TODO: fix transformers generate to have contiguous input_ids here already
# For an unknown reason, calling `contiguous()` here is necessary to not have errors
# on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding.
model_inputs = [input_ids.contiguous()]

if "encoder_hidden_states" in self.input_names:
model_inputs.append(encoder_hidden_states)
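Note on the `contiguous()` fix above: IO binding hands ONNX Runtime a raw data pointer plus a logical shape, so a strided (non-contiguous) `input_ids` view, for example one left behind when beam search reorders or slices the batch, can be read with the wrong memory layout. A minimal illustrative sketch, not part of the commit, with made-up tensor shapes:

```python
import torch

# A strided view: logically (4, 3) but not laid out contiguously in memory.
ids = torch.arange(24).reshape(4, 6)
view = ids[:, ::2]
print(view.is_contiguous())   # False

# IO binding passes view.data_ptr() and the shape to ONNX Runtime, so the
# buffer must be densely laid out before binding; .contiguous() copies the
# values into such a buffer without changing them.
bound = view.contiguous()
print(bound.is_contiguous())  # True
assert torch.equal(bound, view)
```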
7 changes: 7 additions & 0 deletions optimum/onnxruntime/modeling_decoder.py
@@ -722,6 +722,13 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_
)
return model_inputs

# Copied from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)


class ORTBloomForCausalLM(ORTModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
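Background on the `_reorder_cache` addition: gpt_bigcode (StarCoder) uses multi-query attention, so its past cache is a single fused tensor per layer with the batch/beam dimension first, rather than a (key, value) pair per layer. Beam search therefore reorders the cache with one `index_select` per layer, which is what the copied method does. A small sketch, not part of the commit, with made-up shapes:

```python
import torch

# Toy gpt_bigcode-style cache: one fused key/value tensor per layer,
# batch (here: batch_size * num_beams = 4) first. Shapes are illustrative.
past_key_values = tuple(torch.randn(4, 7, 16) for _ in range(2))

# After a beam-search step, entry i says which previous beam hypothesis i
# should continue from.
beam_idx = torch.tensor([2, 2, 0, 1])

reordered = tuple(
    layer_past.index_select(0, beam_idx.to(layer_past.device))
    for layer_past in past_key_values
)

assert reordered[0].shape == past_key_values[0].shape
# Hypothesis 0 now carries the cache that previously belonged to beam 2.
assert torch.equal(reordered[0][0], past_key_values[0][2])
```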
85 changes: 65 additions & 20 deletions tests/onnxruntime/test_modeling.py
@@ -2322,8 +2322,8 @@ def test_merge_from_onnx_and_save(self, model_arch):
self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents)
self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents)

@parameterized.expand(grid_parameters(FULL_GRID))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
@parameterized.expand(grid_parameters({**FULL_GRID, "num_beams": [1, 3]}))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, num_beams: int):
use_io_binding = None
if use_cache is False:
use_io_binding = False
@@ -2384,17 +2384,19 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
if model_arch == "falcon":
# TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released, falcon is broken in transformers
new_tokens = 5

onnx_outputs = onnx_model.generate(
**tokens,
num_beams=1,
num_beams=num_beams,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
eos_token_id=None,
)

transformers_outputs = transformers_model.generate(
**tokens,
num_beams=1,
num_beams=num_beams,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
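The tests above now include `num_beams` in the parameter grid, so both greedy decoding (num_beams=1) and beam search are compared against transformers, since issues like a missing `_reorder_cache` only surface when beams are reordered. `grid_parameters` comes from optimum's test utilities; below is a rough, hypothetical stand-in for the cartesian-product expansion it performs (the real helper builds test names differently):

```python
from itertools import product

def expand_grid(grid: dict):
    # Hypothetical re-implementation for illustration only: yield one
    # parameterized test case per combination of the grid values.
    keys = list(grid)
    for combo in product(*(grid[k] for k in keys)):
        test_name = "_".join(str(v) for v in combo)
        yield [test_name, *combo]

for case in expand_grid({"use_cache": [False, True], "num_beams": [1, 3]}):
    print(case)  # e.g. ['False_1', False, 1], ['False_3', False, 3], ...
```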
@@ -4123,11 +4125,23 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
@require_torch_gpu
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4159,8 +4173,8 @@ def test_compare_generation_to_io_binding(

tokenizer = get_preprocessor(model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt").to("cuda")
onnx_outputs = onnx_model.generate(**tokens, num_beams=5)
io_outputs = io_model.generate(**tokens, num_beams=5)
onnx_outputs = onnx_model.generate(**tokens, num_beams=num_beams)
io_outputs = io_model.generate(**tokens, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -4555,12 +4569,24 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 5],
}
)
)
@require_torch_gpu
@pytest.mark.cuda_ep_test
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4586,8 +4612,8 @@ def test_compare_generation_to_io_binding(
data = self._generate_random_audio_data()
features = processor.feature_extractor(data, return_tensors="pt").to("cuda")

onnx_outputs = onnx_model.generate(**features, num_beams=5)
io_outputs = io_model.generate(**features, num_beams=5)
onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
io_outputs = io_model.generate(**features, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -4920,12 +4946,19 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
@require_torch_gpu
@pytest.mark.cuda_ep_test
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, num_beams: int
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -4951,8 +4984,8 @@ def test_compare_generation_to_io_binding(
data = self._get_sample_image()
features = feature_extractor(data, return_tensors="pt").to("cuda")

onnx_outputs = onnx_model.generate(**features, num_beams=5)
io_outputs = io_model.generate(**features, num_beams=5)
onnx_outputs = onnx_model.generate(**features, num_beams=num_beams)
io_outputs = io_model.generate(**features, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
@@ -5336,10 +5369,22 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
gc.collect()

@parameterized.expand(
grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
grid_parameters(
{
"model_arch": SUPPORTED_ARCHITECTURES,
"use_cache": [True],
"use_merged": [False, True],
"num_beams": [1, 3],
}
)
)
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
self,
test_name: str,
model_arch: str,
use_cache: bool,
use_merged: bool,
num_beams: int,
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")
@@ -5362,8 +5407,8 @@ def test_compare_generation_to_io_binding(
inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt")
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]
onnx_outputs = onnx_model.generate(**inputs, num_beams=5)
io_outputs = io_model.generate(**inputs, num_beams=5)
onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams)
io_outputs = io_model.generate(**inputs, num_beams=num_beams)

# compare tensor outputs
self.assertTrue(torch.equal(onnx_outputs, io_outputs))
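For reference, a minimal usage sketch of the scenario these changes cover: beam-search generation with a gpt_bigcode/StarCoder checkpoint through ONNX Runtime. This is not part of the commit; the checkpoint name and prompt are only examples:

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "bigcode/tiny_starcoder_py"  # example gpt_bigcode checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to ONNX on the fly; use_cache keeps past key/values.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
output = model.generate(**inputs, num_beams=3, do_sample=False, max_new_tokens=16)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```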
