Commit

fix codegen
echarlaix committed Sep 25, 2024
1 parent fadadc9 commit 9fa9e9f
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions optimum/bettertransformer/models/attention.py
@@ -195,7 +195,7 @@ def codegen_wrapped_scaled_dot_product(
                 query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
             )
         else:
-            # in this case, which is the later decoding steps, the `causal_mask`` in
+            # in this case, which is the later decoding steps, the `causal_mask` in
             # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
             # is [True, ..., True] so actually not causal
             sdpa_result = torch.nn.functional.scaled_dot_product_attention(
@@ -207,15 +207,20 @@ def codegen_wrapped_scaled_dot_product(
         # causal_mask is always [True, ..., True] otherwise, so executing this
         # is unnecessary
         if query_length > 1:
-            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            if not check_if_transformers_greater("4.44.99"):
+                causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(
+                    torch.bool
+                )
 
-            causal_mask = torch.where(causal_mask, 0, mask_value)
+                causal_mask = torch.where(causal_mask, 0, mask_value)
 
-            # torch.Tensor.expand does no memory copy
-            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+                # torch.Tensor.expand does no memory copy
+                causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
 
-            # we use torch.min to avoid having tensor(-inf)
-            attention_mask = torch.min(causal_mask, attention_mask)
+                # we use torch.min to avoid having tensor(-inf)
+                attention_mask = torch.min(causal_mask, attention_mask)
+            else:
+                attention_mask = attention_mask[:, :, :, : key.shape[-2]]
 
         sdpa_result = torch.nn.functional.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
@@ -224,6 +229,7 @@ def codegen_wrapped_scaled_dot_product(
     return sdpa_result, None
 
 
+
 # Adapted from transformers.models.opt.modeling_opt.OPTAttention.forward
 def opt_forward(
     self,
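
Note on the change above: the commit keeps the old buffer-based causal-mask construction only for older transformers releases. The sketch below is a hedged illustration, not optimum's actual code: check_if_transformers_greater is assumed to compare the installed transformers version against the target string (so "4.44.99" effectively selects transformers >= 4.45), and build_codegen_attention_mask is a hypothetical name introduced only to make the two branches runnable in isolation.

import torch
import transformers
from packaging import version


def check_if_transformers_greater(target: str) -> bool:
    # Assumed semantics, mirroring the helper used in the diff: True when the
    # installed transformers release is at least `target`.
    return version.parse(transformers.__version__) >= version.parse(target)


def build_codegen_attention_mask(
    causal_mask_buffer, attention_mask, key, batch_size, query_length, key_length, mask_value
):
    # Hypothetical helper reproducing the two branches introduced by this commit.
    if not check_if_transformers_greater("4.44.99"):
        # Legacy path (transformers < 4.45): slice the registered boolean buffer,
        # turn it into an additive float mask and merge it with the padding mask.
        causal_mask = causal_mask_buffer[
            :, :, key_length - query_length : key_length, :key_length
        ].to(torch.bool)
        causal_mask = torch.where(causal_mask, 0, mask_value)
        # torch.Tensor.expand does no memory copy
        causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
        # torch.min avoids ending up with tensor(-inf) entries
        return torch.min(causal_mask, attention_mask)
    # New path (transformers >= 4.45): the incoming 4D mask is assumed to already
    # encode causality, so it only needs cropping to the current key length.
    return attention_mask[:, :, :, : key.shape[-2]]

Under this sketch, only installations older than 4.45 still read the self.causal_mask buffer; newer releases take the slicing path, which appears to be what this "fix codegen" commit addresses after the upstream mask handling changed.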
