Skip to content

Commit

Permalink
[Bug] Fix DeepSeek-V2 attention dispatch: plumb `softmax_scale` through the flash-attention wrappers (#873)
Browse files Browse the repository at this point in the history
Fix the DeepSeek-V2 attention dispatch: `flash_attn_w_mask` and `varlen_flash_attn` previously ignored the model's custom `softmax_scale`, so flash-attention fell back to its default scaling. Add a `softmax_scale=None` keyword parameter (backward-compatible) to both wrappers, forward it to the underlying flash-attention calls, and pass `self.softmax_scale` at both DeepSeek-V2 call sites.
  • Loading branch information
HIT-cwh authored Jul 31, 2024
1 parent d2a173a commit 01640b0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xtuner/model/modules/dispatch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def flash_attn_w_mask(
key_states,
value_states,
attention_mask,
softmax_scale=None,
causal=True,
dropout_p=0.0,
window_size=(-1, -1), # -1 means infinite context window
Expand All @@ -57,6 +58,7 @@ def flash_attn_w_mask(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
softmax_scale=softmax_scale,
dropout_p=dropout_p,
causal=causal,
window_size=window_size)
Expand All @@ -71,6 +73,7 @@ def varlen_flash_attn(
value_states,
cumulative_len,
max_seqlen,
softmax_scale=None,
dropout_p=0.,
causal=True,
window_size=(-1, -1), # -1 means infinite context window
Expand All @@ -85,6 +88,7 @@ def varlen_flash_attn(
cumulative_len,
max_seqlen,
max_seqlen,
softmax_scale=softmax_scale,
dropout_p=dropout_p,
return_attn_probs=False,
causal=causal,
Expand Down
2 changes: 2 additions & 0 deletions xtuner/model/modules/dispatch/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ def deepseek_varlen_attn_forward(
value_states,
cumulative_len,
max_seqlen,
softmax_scale=self.softmax_scale,
causal=causal,
dropout_p=dropout_rate,
training=True)
Expand All @@ -287,6 +288,7 @@ def deepseek_varlen_attn_forward(
query_states,
key_states,
value_states,
softmax_scale=self.softmax_scale,
causal=causal,
dropout_p=dropout_rate,
training=False)
Expand Down

0 comments on commit 01640b0

Please sign in to comment.