diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 4f3b9c895f..59a21f944d 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -433,7 +433,13 @@ def __init__(
     ):
         super().__init__(session, parent_model)

-        if self.parent_model.use_merged is False and self.use_past is True:
+        # We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where decoders such as gpt2
+        # can be used but do not support KV caching for the cross-attention key/values, see:
+        # https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/gpt2/modeling_gpt2.py#L302-L311
+        # This attribute is used to avoid returning the cross-attention KV cache in this case.
+        self.no_cross_attention_cache = getattr(self.parent_model, "no_cross_attention_cache", False)
+
+        if (not self.parent_model.use_merged and self.use_past) or self.no_cross_attention_cache:
             self.num_pkv = 2
         else:
             # When using a merged model, we always have the same number of output whether we use past key values or not,
@@ -688,7 +694,7 @@ def forward(
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
             # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
             # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant)
-            if self.use_past is False or use_merged_no_cache:
+            if not self.use_past or use_merged_no_cache or self.no_cross_attention_cache:
                 out_past_key_values = tuple(
                     out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv)
                 )
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index cdd8d1b6cd..ee09713390 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -1244,6 +1244,11 @@ def __init__(
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ):
+        # There are probably other archs that do not support cross-attention KV cache, but only
+        # this one seems popular on the Hub.
+        if config.decoder.model_type == "gpt2":
+            self.no_cross_attention_cache = True
+
         super().__init__(
             encoder_session,
             decoder_session,
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 0ba847a0d6..6ffbbb7732 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -4023,28 +4023,56 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool):
         feature_extractor, tokenizer = self._get_preprocessors(model_id)

         data = self._get_sample_image()
-        features = feature_extractor(data, return_tensors="pt")

         start_token = "<s>"
         decoder_start_token_id = tokenizer.encode(start_token)[0]
-        decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

-        with torch.no_grad():
-            transformers_outputs = transformers_model(**features, **decoder_inputs)
+        extra_inputs = [{}, {}]

-        for input_type in ["pt", "np"]:
-            features = feature_extractor(data, return_tensors=input_type)
+        if use_cache and False:
+            # TODO: the past key values dims are model-specific, so this will fail with other models
+            fake_pkv = tuple((torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8)) for _ in range(5))
+            extra_inputs[1]["past_key_values"] = fake_pkv

-            if input_type == "np":
-                decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}
+        for extra_inps in extra_inputs:
+            features = feature_extractor(data, return_tensors="pt")
+            decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

-            onnx_outputs = onnx_model(**features, **decoder_inputs)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps)
+            for input_type in ["pt", "np"]:
+                features = feature_extractor(data, return_tensors=input_type)

-            self.assertTrue("logits" in onnx_outputs)
-            self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+                if input_type == "np":
+                    decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}

-        # Compare tensor outputs
-        self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3))
+                if "past_key_values" in extra_inps:
+                    del extra_inps["past_key_values"]  # test past_key_values inputs only with PyTorch
+
+                onnx_outputs = onnx_model(**features, **decoder_inputs, **extra_inps)
+
+                self.assertTrue("logits" in onnx_outputs)
+                self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
+
+                if use_cache:
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"]), len(transformers_outputs["past_key_values"])
+                    )
+                    self.assertEqual(
+                        len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
+                    )
+                    for i, _ in enumerate(onnx_outputs["past_key_values"]):
+                        for j, ort_pkv in enumerate(onnx_outputs["past_key_values"][i]):
+                            trfs_pkv = transformers_outputs["past_key_values"][i][j]
+                            self.assertTrue(
+                                torch.allclose(ort_pkv, trfs_pkv, atol=1e-3),
+                                f"Max diff: {torch.abs(ort_pkv - trfs_pkv).max()}",
+                            )
+
+                # Compare tensor outputs
+                self.assertTrue(
+                    torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-3)
+                )

         gc.collect()
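
Note on the `num_pkv` change in base.py: the decoder session returns the present key/values as one flat tuple, which `forward` reslices into per-layer tuples of `num_pkv` tensors. With a gpt2 decoder there is no cross-attention cache, so each layer contributes 2 tensors instead of 4. A standalone sketch of that regrouping (the tensor shapes are illustrative, not taken from the patch):

    import torch

    # Flat decoder outputs for a 2-layer decoder with self-attention cache only:
    # (key_0, value_0, key_1, value_1)
    flat_outputs = tuple(torch.zeros(1, 4, 1, 8) for _ in range(4))

    num_pkv = 2  # k/v of self-attention only, as for a gpt2 decoder
    past_key_values = tuple(
        flat_outputs[i : i + num_pkv] for i in range(0, len(flat_outputs), num_pkv)
    )
    assert len(past_key_values) == 2
    assert all(len(layer) == 2 for layer in past_key_values)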
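
End to end, the case this patch fixes is ONNX Runtime inference for a vision-encoder-decoder checkpoint with a gpt2 decoder and use_cache=True. A minimal sketch, assuming the nlpconnect/vit-gpt2-image-captioning checkpoint and a sample COCO image (both illustrative, not part of this patch):

    import requests
    from PIL import Image
    from transformers import AutoImageProcessor, AutoTokenizer

    from optimum.onnxruntime import ORTModelForVision2Seq

    model_id = "nlpconnect/vit-gpt2-image-captioning"  # assumed checkpoint for illustration
    processor = AutoImageProcessor.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # export=True converts the checkpoint to ONNX on the fly; use_cache=True takes the
    # decoder-with-past path, which for gpt2 now caches self-attention k/v only.
    model = ORTModelForVision2Seq.from_pretrained(model_id, export=True, use_cache=True)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    pixel_values = processor(image, return_tensors="pt").pixel_values

    generated_ids = model.generate(pixel_values, max_new_tokens=16)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))

With the fix, the past_key_values returned by the ONNX decoder for such models follows the same layout as transformers: two tensors per layer rather than four.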