MPT: Change order of operands to enable PT2 compile for inference (#559)
tdoublep committed Aug 28, 2023
1 parent aad9f64 commit a8c7dc4
Showing 1 changed file with 2 additions and 3 deletions.
llmfoundry/models/mpt/modeling_mpt.py (2 additions, 3 deletions)
```diff
@@ -326,9 +326,8 @@ def forward(
                 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
             )
 
-        if (attention_mask is not None and
-                attention_mask[:, 0].sum() != attention_mask.shape[0] and
-                self.training):
+        if (self.training and attention_mask is not None and
+                attention_mask[:, 0].sum() != attention_mask.shape[0]):
             raise NotImplementedError(
                 'MPT does not support training with left padding.')
```

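For context, a minimal sketch of why the operand order matters under PT2 (torch.compile). In the old order, the data-dependent tensor comparison `attention_mask[:, 0].sum() != attention_mask.shape[0]` was evaluated first, forcing the compiler to guard on (or graph-break over) a runtime tensor value even in eval mode. With `self.training` first, the `and` chain short-circuits on a plain Python bool during inference, so the tensor comparison is never traced. The `Toy` module below is a hypothetical illustration assuming torch >= 2.0, not code from the repository:

```python
import torch


class Toy(torch.nn.Module):

    def forward(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # New operand order: `self.training` is a plain Python bool, so in
        # eval mode the condition short-circuits and the data-dependent
        # tensor comparison is never reached by the compiler's tracer.
        if (self.training and attention_mask is not None and
                attention_mask[:, 0].sum() != attention_mask.shape[0]):
            raise NotImplementedError(
                'MPT does not support training with left padding.')
        return attention_mask.float().mean()


model = Toy().eval()
compiled = torch.compile(model)  # compiles cleanly for inference
mask = torch.ones(4, 16, dtype=torch.long)
print(compiled(mask))
```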
