From 9fa9e9fc3ebd4ca719ac88aa06942cf6ae1da2d8 Mon Sep 17 00:00:00 2001
From: Ella Charlaix
Date: Wed, 25 Sep 2024 19:01:54 +0200
Subject: [PATCH] fix codegen

---
 optimum/bettertransformer/models/attention.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 9dfa57844d..53e6a676e6 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -195,7 +195,7 @@ def codegen_wrapped_scaled_dot_product(
                 query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
-            # in this case, which is the later decoding steps, the `causal_mask`` in
+            # in this case, which is the later decoding steps, the `causal_mask` in
             # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
             # is [True, ..., True] so actually not causal
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -207,15 +207,20 @@ def codegen_wrapped_scaled_dot_product(
         # causal_mask is always [True, ..., True] otherwise, so executing this
         # is unnecessary
         if query_length > 1:
-            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            causal_mask = torch.where(causal_mask, 0, mask_value)
+            if not check_if_transformers_greater("4.44.99"):
+                causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
 
-            # torch.Tensor.expand does no memory copy
-            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+                causal_mask = torch.where(causal_mask, 0, mask_value)
 
-            # we use torch.min to avoid having tensor(-inf)
-            attention_mask = torch.min(causal_mask, attention_mask)
+                # torch.Tensor.expand does no memory copy
+                causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+                # we use torch.min to avoid having tensor(-inf)
+                attention_mask = torch.min(causal_mask, attention_mask)
+            else:
+
+                attention_mask = attention_mask[:, :, :, : key.shape[-2]]
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
         )
@@ -224,6 +229,7 @@ def codegen_wrapped_scaled_dot_product(
 
     return sdpa_result, None
 
+
 # Adapted from transformers.models.opt.modeling_opt.OPTAttention.forward
 def opt_forward(
     self,
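
For context, a minimal standalone sketch of the mask handling this patch gates on the installed transformers version. The helper name `build_sdpa_attention_mask` and its signature are hypothetical (the real logic lives inline in `codegen_wrapped_scaled_dot_product`); `check_if_transformers_greater` is the helper from `optimum.utils` used in the diff. On transformers >= 4.45 (hence the "4.44.99" threshold) the 4D mask handed to the wrapper is used directly and only sliced to the current key length; on older versions the additive causal mask is rebuilt from the module's boolean `causal_mask` buffer and merged with the padding mask via `torch.min`.

    # Illustrative sketch only: `build_sdpa_attention_mask` and its arguments are
    # hypothetical, not optimum helpers. The version check mirrors the
    # `check_if_transformers_greater("4.44.99")` gate added by this patch.
    import torch

    from optimum.utils import check_if_transformers_greater


    def build_sdpa_attention_mask(attention_mask, causal_mask_buffer, batch_size, query_length, key_length, mask_value):
        # Returns the additive mask passed to torch.nn.functional.scaled_dot_product_attention.
        if not check_if_transformers_greater("4.44.99"):
            # transformers < 4.45: rebuild the additive causal mask from the module's
            # boolean `causal_mask` buffer, then merge it with the padding mask.
            causal_mask = causal_mask_buffer[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
            causal_mask = torch.where(causal_mask, 0, mask_value)

            # torch.Tensor.expand does no memory copy
            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

            # torch.min avoids adding two large negative values into tensor(-inf)
            return torch.min(causal_mask, attention_mask)

        # transformers >= 4.45: the 4D mask passed in already encodes causality,
        # so it only needs to be sliced to the current key length.
        return attention_mask[:, :, :, :key_length]

Either branch yields a mask that can be fed to scaled_dot_product_attention with is_causal=False, as in the unchanged call site below the edited block.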