Fix vision encoder decoder that may not cache cross-attention #1210

Merged · 3 commits · Jul 20, 2023
10 changes: 8 additions & 2 deletions optimum/onnxruntime/base.py
@@ -433,7 +433,13 @@ def __init__(
     ):
         super().__init__(session, parent_model)
 
-        if self.parent_model.use_merged is False and self.use_past is True:
+        # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models such as gpt2
+        # can be used but do not support KV caching for the cross-attention key/values, see:
+        # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
+        # This attribute is used to avoid returning cross-attention KV-cache in this case.
+        self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)
+
+        if (not self.parent_model.use_merged and self.use_past) or self.no_cross_attention_cache:
             self.num_pkv = 2
         else:
             # When using a merged model, we always have the same number of outputs whether we use past key values or not,
@@ -688,7 +694,7 @@ def forward(
         # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
         # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
         # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
-        if self.use_past is False or use_merged_no_cache:
+        if not self.use_past or use_merged_no_cache or self.no_cross_attention_cache:
             out_past_key_values = tuple(
                 out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
             )
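For context on the `num_pkv` change above, here is a minimal standalone sketch (not part of the PR; the helper name and the stand-in tensors are made up for illustration) of how the flat ONNX decoder outputs are regrouped into per-layer past key/values, with `num_pkv = 2` when cross-attention key/values are not cached and 4 when they are:

# Illustrative sketch only: regroup a flat tuple of present key/values into per-layer tuples,
# mirroring the grouping done in ORTDecoderForSeq2Seq.forward.
from typing import Any, Tuple

def group_past_key_values(flat_outputs: Tuple[Any, ...], num_pkv: int) -> Tuple[Tuple[Any, ...], ...]:
    return tuple(flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv))

flat = tuple(f"tensor_{i}" for i in range(12))  # stand-in for 12 flat ONNX output tensors

# gpt2-style decoder (no cross-attention cache): 12 tensors -> 6 layers of (self_k, self_v)
assert len(group_past_key_values(flat, num_pkv=2)) == 6

# decoder that also returns cross-attention k/v: 12 tensors -> 3 layers of (self_k, self_v, cross_k, cross_v)
assert len(group_past_key_values(flat, num_pkv=4)) == 3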
5 changes: 5 additions & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -1244,6 +1244,11 @@ def __init__(
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ):
+        # There are probably other archs that do not support cross-attention KV cache, but only
+        # this one seems popular on the Hub.
+        if config.decoder.model_type == "gpt2":
+            self.no_cross_attention_cache = True
+
         super().__init__(
             encoder_session,
             decoder_session,
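For background on the gpt2 special case above, the sketch below (not part of the PR; it assumes transformers v4.31 and the public nlpconnect/vit-gpt2-image-captioning checkpoint, and feeds a random image-shaped tensor purely for shape inspection) checks that a GPT-2 decoder inside a VisionEncoderDecoderModel returns only self-attention key/values per layer, i.e. no cross-attention cache:

# Sketch: with a gpt2 decoder, each entry of past_key_values holds 2 tensors (self-attention k/v),
# not 4, because gpt2 does not cache cross-attention key/values.
import torch
from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
model.eval()

pixel_values = torch.rand(1, 3, 224, 224)  # random input, for shape checking only
decoder_input_ids = torch.tensor([[model.config.decoder.bos_token_id]])

with torch.no_grad():
    outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, use_cache=True)

print(len(outputs.past_key_values[0]))  # expected: 2 (k/v of self-attention only)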
54 changes: 41 additions & 13 deletions tests/onnxruntime/test_modeling.py
@@ -4023,28 +4023,56 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache
         feature_extractor, tokenizer = self._get_preprocessors(model_id)
 
         data = self._get_sample_image()
-        features = feature_extractor(data, return_tensors="pt")
 
         start_token = "<s>"
         decoder_start_token_id = tokenizer.encode(start_token)[0]
-        decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
 
-        with torch.no_grad():
-            transformers_outputs = transformers_model(**features, **decoder_inputs)
+        extra_inputs = [{}, {}]
 
-        for input_type in ["pt", "np"]:
-            features = feature_extractor(data, return_tensors=input_type)
+        if use_cache and False:
+            # TODO: the dims will fail with other models
+            fake_pkv = tuple((torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8)) for _ in range(5))
+            extra_inputs[1]["past_key_values"] = fake_pkv
 
-            if input_type == "np":
-                decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}
+        for extra_inps in extra_inputs:
+            features = feature_extractor(data, return_tensors="pt")
+            decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}
 
-            onnx_outputs = onnx_model(**features, **decoder_inputs)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps)
+            for input_type in ["pt", "np"]:
+                features = feature_extractor(data, return_tensors=input_type)
 
-            self.assertTrue("logits" in onnx_outputs)
-            self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+                if input_type == "np":
+                    decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}
 
-            # Compare tensor outputs
-            self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3))
+                    if "past_key_values" in extra_inps:
+                        del extra_inps["past_key_values"]  # test only with pytorch
+
+                onnx_outputs = onnx_model(**features, **decoder_inputs, **extra_inps)
+
+                self.assertTrue("logits" in onnx_outputs)
+                self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+
+                if use_cache:
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"]), len(transformers_outputs["past_key_values"])
+                    )
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
+                    )
+                    for i, _ in enumerate(onnx_outputs["past_key_values"]):
+                        for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]):
+                            trfs_pkv = transformers_outputs["past_key_values"][i][j]
+                            self.assertTrue(
+                                torch.allclose(ort_pkv, trfs_pkv, atol=1e-3),
+                                f" Maxdiff: {torch.abs(ort_pkv - trfs_pkv).max()}",
+                            )
+
+                # Compare tensor outputs
+                self.assertTrue(
+                    torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
+                )
 
         gc.collect()
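Finally, a hedged end-to-end usage sketch (not part of the PR; the checkpoint name, image URL and generation arguments are assumptions) showing the path this PR fixes, i.e. exporting a ViT + GPT-2 captioning model with ORTModelForVision2Seq and generating with the decoder KV cache enabled:

# Sketch: export a vision-encoder-decoder checkpoint with a gpt2 decoder to ONNX Runtime
# and generate with use_cache=True, which exercises the code path touched by this PR.
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoTokenizer
from optimum.onnxruntime import ORTModelForVision2Seq

model_id = "nlpconnect/vit-gpt2-image-captioning"  # assumed example checkpoint
model = ORTModelForVision2Seq.from_pretrained(model_id, export=True, use_cache=True)
image_processor = AutoImageProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = image_processor(images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))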
