Skip to content

Commit

Permalink
fix vision2seq tests as it seems to have had always outputed kv cache…
Browse files Browse the repository at this point in the history
… in torch format before
  • Loading branch information
IlyasMoutawwakil committed Jun 4, 2024
1 parent aaf6cd6 commit 20ccd8e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4766,6 +4766,9 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach

self.assertTrue("logits" in onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
self.assertTrue(
torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
)

if use_cache:
self.assertEqual(
Expand All @@ -4774,19 +4777,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
self.assertEqual(
len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
)
for i, _ in enumerate(onnx_outputs["past_key_values"]):
for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]):
trfs_pkv = transformers_outputs["past_key_values"][i][j]
for i in range(len(onnx_outputs["past_key_values"])):
print(onnx_outputs["past_key_values"][i])
for ort_pkv, trfs_pkv in zip(
onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
):
ort_pkv = torch.Tensor(ort_pkv)
self.assertTrue(
torch.allclose(ort_pkv, trfs_pkv, atol=1e-3),
f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}",
)

# Compare tensor outputs
self.assertTrue(
torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
)

gc.collect()

@parameterized.expand(grid_parameters(FULL_GRID))
Expand Down

0 comments on commit 20ccd8e

Please sign in to comment.