Skip to content

Commit

Permalink
added io tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Aug 17, 2023
1 parent eb55f96 commit c629059
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 2 deletions.
3 changes: 2 additions & 1 deletion optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor) -
)
else:
raise ValueError(
f"Unsupported encoder model type {self.normalized_config.config.model_type} for VisionEncoderDecoder."
f"Unsupported encoder model type {self.normalized_config.config.model_type} for ORTForVisionSeq2Seq with IOBinding."
"Currently supported models are vit and donut-swin."
"Please submit a PR to add support for this model type."
)

Expand Down
87 changes: 87 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4187,6 +4187,93 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):
f" speedup: {without_pkv_timer.elapsed / with_pkv_timer.elapsed:.3f}",
)

@parameterized.expand(
    grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
@require_torch_gpu
@pytest.mark.gpu_test
def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
    """Compare the logits of a forward pass with and without ``use_io_binding``.

    The same exported model directory is loaded twice, once with
    ``use_io_binding=False`` and once with ``use_io_binding=True``. Both
    models are moved to ``"cuda"``, fed the same batch of two images plus a
    single start token per sample, and their ``logits`` must be exactly equal.

    Args:
        test_name: Key into ``self.onnx_model_dirs`` for the exported model.
        model_arch: Key into ``MODEL_NAMES`` selecting the checkpoint.
        use_cache: Forwarded to ``self._setup`` through ``model_args``.
        use_merged: Forwarded to ``self._setup`` through ``model_args``.
    """
    if use_cache is False and use_merged is True:
        self.skipTest("use_cache=False, use_merged=True are uncompatible")

    model_args = {
        "test_name": test_name,
        "model_arch": model_arch,
        "use_cache": use_cache,
        "use_merged": use_merged,
    }
    self._setup(model_args)

    model_id = MODEL_NAMES[model_arch]
    onnx_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False).to(
        "cuda"
    )
    io_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True).to(
        "cuda"
    )

    self.assertFalse(onnx_model.use_io_binding)
    self.assertTrue(io_model.use_io_binding)

    # Only the image preprocessor is used by this test; the second returned
    # value is intentionally discarded.
    feature_extractor, _ = self._get_preprocessors(model_id)

    data = self._get_sample_image()
    features = feature_extractor([data] * 2, return_tensors="pt").to("cuda")

    decoder_start_token_id = onnx_model.config.decoder.bos_token_id
    # Create the decoder ids directly on "cuda" so that every input lives on
    # the same device as the models and the image features.
    # NOTE(review): IO binding presumably needs device-resident inputs; a CPU
    # tensor here mixed devices within a single call - confirm against ORT docs.
    decoder_inputs = {
        "decoder_input_ids": torch.full((2, 1), decoder_start_token_id, dtype=torch.long, device="cuda")
    }

    onnx_outputs = onnx_model(**features, **decoder_inputs)
    io_outputs = io_model(**features, **decoder_inputs)

    self.assertIn("logits", io_outputs)
    self.assertIsInstance(io_outputs.logits, torch.Tensor)

    # compare tensor outputs
    self.assertTrue(torch.equal(onnx_outputs.logits, io_outputs.logits))

    gc.collect()

@parameterized.expand(
    grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True], "use_merged": [False, True]})
)
@require_torch_gpu
@pytest.mark.gpu_test
def test_compare_generation_to_io_binding(
    self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool
):
    """Compare ``generate`` output with and without ``use_io_binding``.

    The exported model directory is loaded twice (``use_io_binding`` off and
    on), both models are moved to ``"cuda"``, and ``generate`` is run with
    ``num_beams=5`` on the same preprocessed sample image. The two generated
    tensors must be exactly equal.

    Args:
        test_name: Key into ``self.onnx_model_dirs`` for the exported model.
        model_arch: Key into ``MODEL_NAMES`` selecting the checkpoint.
        use_cache: Forwarded to ``self._setup``.
        use_merged: Forwarded to ``self._setup``.
    """
    if use_merged is True and use_cache is False:
        self.skipTest("use_cache=False, use_merged=True are uncompatible")

    self._setup(
        {
            "test_name": test_name,
            "model_arch": model_arch,
            "use_cache": use_cache,
            "use_merged": use_merged,
        }
    )

    checkpoint = MODEL_NAMES[model_arch]

    # Reference model first, then the IO-bound one; both end up on the GPU.
    plain_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=False)
    plain_model = plain_model.to("cuda")
    bound_model = ORTModelForVision2Seq.from_pretrained(self.onnx_model_dirs[test_name], use_io_binding=True)
    bound_model = bound_model.to("cuda")

    # The second preprocessor is not needed for this comparison.
    image_processor, _unused_tokenizer = self._get_preprocessors(checkpoint)

    sample_image = self._get_sample_image()
    model_inputs = image_processor(sample_image, return_tensors="pt").to("cuda")

    generation_kwargs = {"num_beams": 5}
    plain_tokens = plain_model.generate(**model_inputs, **generation_kwargs)
    bound_tokens = bound_model.generate(**model_inputs, **generation_kwargs)

    # Both code paths must produce exactly the same generated tensor.
    outputs_match = torch.equal(plain_tokens, bound_tokens)
    self.assertTrue(outputs_match)

    gc.collect()


class ORTModelForCustomTasksIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES_WITH_MODEL_ID = {
Expand Down
2 changes: 1 addition & 1 deletion tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"deberta": "hf-internal-testing/tiny-random-DebertaModel",
"deberta_v2": "hf-internal-testing/tiny-random-DebertaV2Model",
"deit": "hf-internal-testing/tiny-random-DeiTModel",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
"convnext": "hf-internal-testing/tiny-random-convnext",
"detr": "hf-internal-testing/tiny-random-detr",
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
Expand Down Expand Up @@ -99,7 +100,6 @@
"xlm_roberta": "hf-internal-testing/tiny-xlm-roberta",
"vision-encoder-decoder": "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2",
"trocr": "microsoft/trocr-small-handwritten",
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
}

SEED = 42
Expand Down

0 comments on commit c629059

Please sign in to comment.