diff --git a/llmfoundry/models/utils/bert_padding.py b/llmfoundry/models/utils/bert_padding.py index c9ddb4ab80..725aced703 100644 --- a/llmfoundry/models/utils/bert_padding.py +++ b/llmfoundry/models/utils/bert_padding.py @@ -13,7 +13,7 @@ in `bert_layers.py`. """ -from typing import Tuple, cast +from typing import Tuple, cast, Any import torch import torch.nn.functional as F @@ -23,7 +23,7 @@ class IndexFirstAxis(torch.autograd.Function): @staticmethod - def forward(ctx, input: torch.Tensor, + def forward(ctx: Any, input: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """Get just the values of `input` which are at `indices`. @@ -47,7 +47,7 @@ def forward(ctx, input: torch.Tensor, ).reshape(-1, *other_shape) # (num_idx, ...) @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: indices, = ctx.saved_tensors assert grad_output.ndim >= 2 other_shape = grad_output.shape[1:] @@ -69,8 +69,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: class IndexPutFirstAxis(torch.autograd.Function): @staticmethod - def forward(ctx, values: torch.Tensor, indices: torch.Tensor, - first_axis_dim) -> torch.Tensor: + def forward(ctx: Any, values: torch.Tensor, indices: torch.Tensor, + first_axis_dim: int) -> torch.Tensor: ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 @@ -82,7 +82,7 @@ def forward(ctx, values: torch.Tensor, indices: torch.Tensor, return output @staticmethod - def backward(ctx, + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: indices, = ctx.saved_tensors grad_values = grad_output[indices] @@ -142,7 +142,7 @@ def unpad_input_only( """ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), - indices) + indices) #pyright: ignore[reportGeneralTypeIssues] def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, @@ -159,4 +159,4 @@ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, hidden_states: (batch, seqlen, ...) """ output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, '(b s) ... -> b s ...', b=batch) + return rearrange(output, '(b s) ... -> b s ...', b=batch) #pyright: ignore[reportGeneralTypeIssues]