diff --git a/llmfoundry/models/utils/bert_padding.py b/llmfoundry/models/utils/bert_padding.py index 2f465c37a1..02fe22a23a 100644 --- a/llmfoundry/models/utils/bert_padding.py +++ b/llmfoundry/models/utils/bert_padding.py @@ -161,6 +161,6 @@ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, """ output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange( - output, + output, #pyright: ignore[reportGeneralTypeIssues] '(b s) ... -> b s ...', #pyright: ignore[reportGeneralTypeIssues] b=batch)