Skip to content

Commit

Permalink
Allow MPT models to return attention weights (#599)
Browse files Browse the repository at this point in the history
* Allow MPT models to return attention weights

* Update llmfoundry/models/mpt/modeling_mpt.py

Co-authored-by: Daniel King <[email protected]>

* Add unit test

* Update tests/test_model.py

Co-authored-by: Daniel King <[email protected]>

* Update tests/test_model.py

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
lorabit110 and dakinggg committed Sep 21, 2023
1 parent 299e737 commit 0be2ca8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def forward(
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
Expand All @@ -100,6 +101,7 @@ def forward(
attn_bias=attn_bias,
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
)
x = x + self.resid_attn_dropout(b)
m = x
Expand Down
1 change: 1 addition & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def forward(
attn_bias=attn_bias,
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
)
if past_key_values is not None:
past_key_values[b_idx] = past_key_value
Expand Down
2 changes: 2 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,8 @@ def test_forward_with_output_attentions_and_output_hidden_states(

if output_attentions:
assert len(outputs.attentions) == n_layers
assert all(
attn.shape == (1, 4, 3, 3) for attn in outputs.attentions)
if output_hidden_states:
assert len(outputs.hidden_states) == n_layers + 1

Expand Down

0 comments on commit 0be2ca8

Please sign in to comment.