
Commit

fix scaled_dot_product_attention
RunningLeon committed Sep 21, 2023
1 parent 1bf05bf commit b6cca22
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mmdeploy/pytorch/functions/multi_head_attention_forward.py
@@ -53,3 +53,29 @@ def _scaled_dot_product_attention__tensorrt(q: Tensor,
                                            **kwargs) -> Tuple[Tensor, Tensor]:
    """Rewrite for custom ops."""
    return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)


@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.nn.functional.scaled_dot_product_attention',
    backend=Backend.DEFAULT.value)
def scaled_dot_product_attention__default(query,
                                          key,
                                          value,
                                          attn_mask=None,
                                          dropout_p=0.,
                                          scale=None,
                                          is_causal=False):
    """Rewrite to export to onnx on torch>=2.0.0."""
    scale = scale or query.size(-1)**0.5
    if is_causal and attn_mask is not None:
        attn_mask = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
    if attn_mask is not None and attn_mask.dtype == torch.bool:
        # `not` on a multi-element tensor is ambiguous and raises at runtime;
        # turn the boolean mask into an additive float bias instead: 0 where
        # attention is allowed, -inf where it is masked out.
        attn_mask = torch.zeros_like(
            attn_mask, dtype=query.dtype).masked_fill(attn_mask.logical_not(),
                                                      -float('inf'))

    attn_weight = query @ key.transpose(-2, -1) / scale
    if attn_mask is not None:
        attn_weight += attn_mask
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    return attn_weight @ value
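
A quick way to sanity-check the decomposition above (not part of this commit): with the default scale, no dropout and a boolean mask, the hand-rolled attention should match torch.nn.functional.scaled_dot_product_attention. Shapes and variable names below are illustrative only.

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, seq_len, head_dim).
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    bool_mask = torch.ones(8, 8, dtype=torch.bool).tril(diagonal=0)  # causal

    # Decomposed form, mirroring the rewriter: boolean mask -> additive bias.
    bias = torch.zeros(8, 8).masked_fill(bool_mask.logical_not(), -float('inf'))
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1)**0.5 + bias, dim=-1)
    out_decomposed = attn @ v

    out_reference = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
    print(torch.allclose(out_decomposed, out_reference, atol=1e-6))  # expect True

The rewriter itself needs no direct call: FUNCTION_REWRITER substitutes it for the torch op when mmdeploy exports a model to ONNX with the default backend.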

0 comments on commit b6cca22
