Skip to content

Commit

Permalink
pyright cleanup bert_padding.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
jacobfulano committed Aug 21, 2023
1 parent 69e4e41 commit ade08b8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions llmfoundry/models/utils/bert_padding.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,7 +13,7 @@
in `bert_layers.py`.
"""

from typing import Tuple, cast
from typing import Tuple, cast, Any

import torch
import torch.nn.functional as F
Expand All @@ -23,7 +23,7 @@
class IndexFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, input: torch.Tensor,
def forward(ctx: Any, input: torch.Tensor,
indices: torch.Tensor) -> torch.Tensor:
"""Get just the values of `input` which are at `indices`.
Expand All @@ -47,7 +47,7 @@ def forward(ctx, input: torch.Tensor,
).reshape(-1, *other_shape) # (num_idx, ...)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
indices, = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
Expand All @@ -69,8 +69,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
class IndexPutFirstAxis(torch.autograd.Function):

@staticmethod
def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
first_axis_dim) -> torch.Tensor:
def forward(ctx: Any, values: torch.Tensor, indices: torch.Tensor,
first_axis_dim: int) -> torch.Tensor:
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
Expand All @@ -82,7 +82,7 @@ def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
return output

@staticmethod
def backward(ctx,
def backward(ctx: Any,
grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
indices, = ctx.saved_tensors
grad_values = grad_output[indices]
Expand Down Expand Up @@ -142,7 +142,7 @@ def unpad_input_only(
"""
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
return index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'),
indices)
indices) #pyright: ignore[reportGeneralTypeIssues]


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
Expand All @@ -159,4 +159,4 @@ def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
hidden_states: (batch, seqlen, ...)
"""
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, '(b s) ... -> b s ...', b=batch)
return rearrange(output, '(b s) ... -> b s ...', b=batch) #pyright: ignore[reportGeneralTypeIssues]

0 comments on commit ade08b8

Please sign in to comment.