diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 744886cbbd..5481201a8f 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -326,9 +326,8 @@ def forward(
                 'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
             )
-        if (attention_mask is not None and
-                attention_mask[:, 0].sum() != attention_mask.shape[0] and
-                self.training):
+        if (self.training and attention_mask is not None and
+                attention_mask[:, 0].sum() != attention_mask.shape[0]):
             raise NotImplementedError(
                 'MPT does not support training with left padding.')
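
The change reorders the condition so `self.training` is evaluated first, letting the short-circuit skip the attention-mask inspection outside of training (e.g. during generation, where left padding is expected). Below is a minimal, hedged sketch of the same check pulled out into a standalone helper for illustration; the function name `rejects_left_padding` and the toy masks are hypothetical and not part of the diff or the MPT codebase.

```python
import torch

def rejects_left_padding(training: bool, attention_mask: torch.Tensor) -> bool:
    # attention_mask[:, 0].sum() counts sequences whose first position is
    # unmasked; if any sequence starts with a 0, the batch is left padded.
    # Checking `training` first mirrors the reordered condition in the diff.
    return bool(training and attention_mask is not None and
                attention_mask[:, 0].sum() != attention_mask.shape[0])

right_padded = torch.tensor([[1, 1, 0], [1, 1, 1]])
left_padded = torch.tensor([[0, 1, 1], [1, 1, 1]])

assert not rejects_left_padding(True, right_padded)   # right padding: allowed
assert rejects_left_padding(True, left_padded)        # left padding in training: rejected
assert not rejects_left_padding(False, left_padded)   # inference: mask never inspected
```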