From a8c7dc4786b5054247cd44dca452b2e06538eb6b Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Mon, 28 Aug 2023 20:39:29 +0200
Subject: [PATCH] MPT: Change order of operands to enable PT2 compile for inference (#559)

---
 llmfoundry/models/mpt/modeling_mpt.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 744886cbbd..5481201a8f 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -326,9 +326,8 @@ def forward(
                     'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
                 )
 
-        if (attention_mask is not None and
-                attention_mask[:, 0].sum() != attention_mask.shape[0] and
-                self.training):
+        if (self.training and attention_mask is not None and
+                attention_mask[:, 0].sum() != attention_mask.shape[0]):
             raise NotImplementedError(
                 'MPT does not support training with left padding.')
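
Note (not part of the patch): a minimal sketch of why the operand order matters.
Python's `and` short-circuits left to right, so with `self.training` checked
first, the tensor-dependent comparison `attention_mask[:, 0].sum() != ...`
never executes during inference, and torch.compile does not have to trace a
data-dependent branch. The `Guarded` module, shapes, and call below are
hypothetical and only mirror the guard pattern in the patched code.

import torch


class Guarded(torch.nn.Module):

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
        # Cheap Python-level flag first: in eval mode (self.training is
        # False) the whole condition short-circuits, so the data-dependent
        # tensor comparison below is never traced by torch.compile.
        if (self.training and attention_mask is not None and
                attention_mask[:, 0].sum() != attention_mask.shape[0]):
            raise NotImplementedError(
                'MPT does not support training with left padding.')
        # Placeholder computation standing in for the real forward pass.
        return x * attention_mask.unsqueeze(-1)


model = Guarded().eval()
compiled = torch.compile(model)
x = torch.randn(2, 4, 8)
mask = torch.ones(2, 4)
out = compiled(x, mask)  # compiles without hitting the tensor-valued guard

With the pre-patch ordering (`attention_mask is not None and
attention_mask[:, 0].sum() != ... and self.training`), the tensor comparison
was evaluated even at inference time, forcing the compiler to handle a
data-dependent branch before the always-False `self.training` was ever
consulted.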