Commit

Fix vision encoder decoder that may not cache cross-attention (huggingface#1210)

* fix vision encoder decoder

* add test
fxmarty authored and baskrahmer committed Jul 22, 2023
1 parent a4dbf8b commit 4731e75
Showing 3 changed files with 54 additions and 15 deletions.
10 changes: 8 additions & 2 deletions optimum/onnxruntime/base.py
@@ -433,7 +433,13 @@ def __init__(
     ):
         super().__init__(session, parent_model)

-        if self.parent_model.use_merged is False and self.use_past is True:
+        # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models such as gpt2
+        # can be used but do not support KV caching for the cross-attention key/values, see:
+        # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
+        # This attribute is used to avoid returning cross-attention KV-cache in this case.
+        self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)
+
+        if (not self.parent_model.use_merged and self.use_past) or self.no_cross_attention_cache:
             self.num_pkv = 2
         else:
             # When using a merged model, we always have the same number of outputs whether we use past key values or not,
@@ -688,7 +694,7 @@ def forward(
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
             # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
             # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
-            if self.use_past is False or use_merged_no_cache:
+            if not self.use_past or use_merged_no_cache or self.no_cross_attention_cache:
                 out_past_key_values = tuple(
                     out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
                 )
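For readers skimming the hunk above: num_pkv controls how the flat sequence of past key/value tensors returned by the ONNX decoder session is regrouped into one tuple per layer, either pairs (self-attention only, as for a gpt2 decoder with no cross-attention cache) or quadruples (self-attention plus cross-attention). Below is a minimal sketch of that regrouping; the helper name group_past_key_values and the placeholder strings standing in for tensors are ours, not part of optimum.

    # Sketch of the regrouping done in the diff above; tensor contents are dummy strings.
    from typing import Any, Tuple

    def group_past_key_values(flat_outputs: Tuple[Any, ...], num_pkv: int) -> Tuple[Tuple[Any, ...], ...]:
        # num_pkv == 2 -> (self_attn_key, self_attn_value) per layer
        # num_pkv == 4 -> (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value) per layer
        return tuple(flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv))

    # Two decoder layers, no cross-attention cache (the gpt2 case): pairs per layer.
    flat = ("l0_sa_k", "l0_sa_v", "l1_sa_k", "l1_sa_v")
    assert group_past_key_values(flat, num_pkv=2) == (("l0_sa_k", "l0_sa_v"), ("l1_sa_k", "l1_sa_v"))

    # Two decoder layers with a cross-attention cache: quadruples per layer.
    flat = ("l0_sa_k", "l0_sa_v", "l0_ca_k", "l0_ca_v", "l1_sa_k", "l1_sa_v", "l1_ca_k", "l1_ca_v")
    assert group_past_key_values(flat, num_pkv=4)[0] == ("l0_sa_k", "l0_sa_v", "l0_ca_k", "l0_ca_v")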
5 changes: 5 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -1244,6 +1244,11 @@ def __init__(
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ):
+        # There are probably other archs that do not support cross attention KV cache, but only
+        # this one seems popular on the Hub.
+        if config.decoder.model_type == "gpt2":
+            self.no_cross_attention_cache = True
+
         super().__init__(
             encoder_session,
             decoder_session,
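For context, the configuration check above fires for VisionEncoderDecoder checkpoints whose decoder config reports model_type == "gpt2". A hedged usage sketch of that scenario follows; the model id (a ViT + GPT-2 captioning checkpoint such as nlpconnect/vit-gpt2-image-captioning), the image path, and the export=True / use_cache=True keyword arguments are illustrative assumptions, not taken from this diff.

    # Sketch only: model id, image path, and generation settings are illustrative assumptions.
    from PIL import Image
    from transformers import AutoImageProcessor, AutoTokenizer

    from optimum.onnxruntime import ORTModelForVision2Seq

    model_id = "nlpconnect/vit-gpt2-image-captioning"  # assumed ViT encoder + GPT-2 decoder checkpoint
    processor = AutoImageProcessor.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Export to ONNX on the fly; with a GPT-2 decoder, no cross-attention KV cache is returned,
    # which is the case the no_cross_attention_cache flag handles.
    model = ORTModelForVision2Seq.from_pretrained(model_id, export=True, use_cache=True)

    pixel_values = processor(images=Image.open("example.jpg"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_new_tokens=16)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))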
54 changes: 41 additions & 13 deletions tests/onnxruntime/test_modeling.py
@@ -4023,28 +4023,56 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
         feature_extractor, tokenizer = self._get_preprocessors(model_id)

         data = self._get_sample_image()
-        features = feature_extractor(data, return_tensors="pt")

         start_token = "<s>"
         decoder_start_token_id = tokenizer.encode(start_token)[0]
-        decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

-        with torch.no_grad():
-            transformers_outputs = transformers_model(**features, **decoder_inputs)
+        extra_inputs = [{}, {}]

-        for input_type in ["pt", "np"]:
-            features = feature_extractor(data, return_tensors=input_type)
+        if use_cache and False:
+            # TODO: the dims will fail with other models
+            fake_pkv = tuple((torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8)) for _ in range(5))
+            extra_inputs[1]["past_key_values"] = fake_pkv

-            if input_type == "np":
-                decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}
+        for extra_inps in extra_inputs:
+            features = feature_extractor(data, return_tensors="pt")
+            decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

-            onnx_outputs = onnx_model(**features, **decoder_inputs)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps)
+            for input_type in ["pt", "np"]:
+                features = feature_extractor(data, return_tensors=input_type)

-            self.assertTrue("logits" in onnx_outputs)
-            self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+                if input_type == "np":
+                    decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}

-            # Compare tensor outputs
-            self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3))
+                if "past_key_values" in extra_inps:
+                    del extra_inps["past_key_values"]  # test only with pytorch
+
+                onnx_outputs = onnx_model(**features, **decoder_inputs, **extra_inps)
+
+                self.assertTrue("logits" in onnx_outputs)
+                self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+
+                if use_cache:
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"]), len(transformers_outputs["past_key_values"])
+                    )
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
+                    )
+                    for i, _ in enumerate(onnx_outputs["past_key_values"]):
+                        for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]):
+                            trfs_pkv = transformers_outputs["past_key_values"][i][j]
+                            self.assertTrue(
+                                torch.allclose(ort_pkv, trfs_pkv, atol=1e-3),
+                                f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}",
+                            )
+
+                # Compare tensor outputs
+                self.assertTrue(
+                    torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
+                )

         gc.collect()

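A note on the hard-coded fake_pkv shape in the (currently disabled) use_cache branch of the test: it presumably follows the usual (batch_size, num_heads, past_sequence_length, head_dim) layout of self-attention caches, which is why the TODO warns that the dimensions would not fit other architectures. A small sketch, with the layer count and dimensions chosen to mirror the test above; the values are otherwise arbitrary.

    import torch

    # Shapes chosen to mirror the fake_pkv in the test: (batch_size, num_heads, past_seq_len, head_dim).
    batch_size, num_heads, past_seq_len, head_dim = 1, 4, 1, 8
    num_layers = 5

    # One (key, value) pair of self-attention cache tensors per decoder layer.
    fake_pkv = tuple(
        (
            torch.rand(batch_size, num_heads, past_seq_len, head_dim),
            torch.rand(batch_size, num_heads, past_seq_len, head_dim),
        )
        for _ in range(num_layers)
    )

    assert len(fake_pkv) == num_layers
    assert fake_pkv[0][0].shape == (1, 4, 1, 8)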
