Skip to content

Commit

Permalink
Migrate Type Hints for PEP 585 (mosaicml#3344)
Browse files Browse the repository at this point in the history
* v1

* v1 fix

* fix lint

* lint

* purge

* fix lint

* fix more lint

* fix import

* purge lint

* fix lint
  • Loading branch information
mvpatel2000 committed May 30, 2024
1 parent c9a51d4 commit dfbaf14
Show file tree
Hide file tree
Showing 154 changed files with 895 additions and 915 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Other high-level deep learning trainers provide simplicity at the cost of rigidi
Composer is built to automate away low-level pain points and headaches so you can focus on the important (and fun) parts of deep learning and iterate faster.

- [**Auto-resumption**](https://docs.mosaicml.com/projects/composer/en/stable/notes/resumption.html): Failed training run? Have no fear — just re-run your code, and Composer will automatically resume from your latest saved checkpoint.
- [**CUDA OOM Prevention**](https://docs.mosaicml.com/projects/composer/en/stable/examples/auto_microbatching.html): Say goodbye to out-of-memory errors. Set your microbatch size to “auto”, and Composer will automatically select the biggest one that fits on your GPUs.
- [**CUDA OOM Prevention**](https://docs.mosaicml.com/projects/composer/en/stable/examples/auto_microbatching.html): Say goodbye to out-of-memory errors. Set your microbatch size to “auto”, and Composer will automatically select the biggest one that fits on your GPUs.
- **[Time Abstractions](https://docs.mosaicml.com/projects/composer/en/latest/trainer/time.html):** Ever messed up your conversion between update steps, epochs, samples, and tokens? Specify your training duration with custom units (epochs, batches, samples, and tokens) in your training loop with our `Time` class.

## Integrations
Expand Down
10 changes: 5 additions & 5 deletions STYLE_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ The following rules apply to public APIs:
```python
from torch import Tensor
from typing import Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Union
from composer.utils import ensure_tuple
def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> Tuple[Tensor, ...]:
def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> tuple[Tensor, ...]:
return ensure_tuple(x) # ensures that the result is always a (potentially empty) tuple of tensors
```
Expand Down Expand Up @@ -281,7 +281,7 @@ For example, from [composer/callbacks/memory_monitor.py](composer/callbacks/memo
```python
"""Log memory usage during training."""
import logging
from typing import Dict, Union
from typing import Union
import torch.cuda
Expand Down Expand Up @@ -339,7 +339,7 @@ The following guidelines apply to documentation.
specify "optional", and the docstring should say the default value. Some examples:
```python
from typing import Optional, Tuple, Union
from typing import Optional, Union
def foo(bar: int):
"""Foo.
Expand Down Expand Up @@ -384,7 +384,7 @@ The following guidelines apply to documentation.
"""
...
def foo6(bar: int) -> Tuple[int, str]:
def foo6(bar: int) -> tuple[int, str]:
"""Foo6.
Args:
Expand Down
10 changes: 5 additions & 5 deletions composer/algorithms/alibi/attention_surgery_functions/_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
import math
from types import MethodType
from typing import Optional, Tuple
from typing import Optional

import torch
from torch import nn
Expand Down Expand Up @@ -61,9 +61,9 @@ def forward(
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor]:
"""Replication of identically-named attention function ("forward") in Composer/HuggingFace BERT model's
BERTSelfAttention (:func:`transformers.models.bert.modeling_bert.BERTSelfAttention.forward`), but this function
implements ALiBi and will be used to replace the default attention function."""
Expand Down Expand Up @@ -95,10 +95,10 @@ def forward(
query_layer = self.transpose_for_scores(mixed_query_layer)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# if cross_attention save tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# if uni-directional self-attention (decoder) save tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-License-Identifier: Apache-2.0

from types import MethodType
from typing import Tuple

import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model
Expand Down Expand Up @@ -42,7 +41,7 @@ def gpt2_attention_converter(module: torch.nn.Module, module_index: int, max_seq
return module


def _attn(self, query, key, value, attention_mask=None, head_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
def _attn(self, query, key, value, attention_mask=None, head_mask=None) -> tuple[torch.Tensor, torch.Tensor]:
"""Replication of identically-named attention function ("_attn") in Composer/HuggingFace GPT2 model's
GPT2Attention (:func:`transformers.models.gpt2.modeling_gpt2.GPT2Attention._attn`; `GitHub link <https://\\
github.com/huggingface/transformers/blob/2e11a043374a6229ec129a4765ee4ba7517832b9/src/transformers/models/\\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import math
from operator import attrgetter
from typing import Callable, Dict, Optional, Type
from typing import Callable, Optional, Type

import torch

Expand All @@ -16,7 +16,7 @@
AlibiReplacementFunction = Callable[[torch.nn.Module, int, int], Optional[torch.nn.Module]]


class PolicyRegistry(Dict[Type[torch.nn.Module], AlibiReplacementFunction]):
class PolicyRegistry(dict[Type[torch.nn.Module], AlibiReplacementFunction]):
"""A registry mapping for ALiBi surgery."""

def register(
Expand Down
6 changes: 3 additions & 3 deletions composer/algorithms/augmix/augmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import functools
import textwrap
import weakref
from typing import List, TypeVar
from typing import TypeVar

import numpy as np
import torch
Expand All @@ -32,7 +32,7 @@ def augmix_image(
depth: int = -1,
width: int = 3,
alpha: float = 1.0,
augmentation_set: List = augmentation_sets['all'],
augmentation_set: list = augmentation_sets['all'],
) -> ImgT:
r"""Applies the AugMix (`Hendrycks et al, 2020 <http://arxiv.org/abs/1912.02781>`_) data augmentation.
Expand Down Expand Up @@ -77,7 +77,7 @@ def _augmix_pil_image(
depth: int,
width: int,
alpha: float,
augmentation_set: List,
augmentation_set: list,
) -> PillowImage:
chain_weights = np.random.dirichlet([alpha] * width).astype(np.float32)
mixing_weight = np.float32(np.random.beta(alpha, alpha))
Expand Down
28 changes: 14 additions & 14 deletions composer/algorithms/colout/colout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import textwrap
import weakref
from typing import Any, Callable, Tuple, TypeVar, Union
from typing import Any, Callable, TypeVar, Union

import torch
import torch.utils.data
Expand All @@ -29,11 +29,11 @@


def colout_batch(
sample: Union[ImgT, Tuple[ImgT, ImgT]],
sample: Union[ImgT, tuple[ImgT, ImgT]],
p_row: float = 0.15,
p_col: float = 0.15,
resize_target: Union[bool, str] = 'auto',
) -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
) -> Union[torch.Tensor, ImgT, tuple[Tensor, Tensor], tuple[ImgT, ImgT]]:
"""Applies ColOut augmentation to a batch of images and (optionally) targets,
dropping the same random rows and columns from all images and targets in a batch.
Expand All @@ -46,7 +46,7 @@ def colout_batch(
new_X = colout_batch(X_example, p_row=0.15, p_col=0.15)
Args:
sample (torch.Tensor | PIL.Image | Tuple[torch.Tensor, torch.Tensor] | Tuple[PIL.Image, PIL.Image]):
sample (torch.Tensor | PIL.Image | tuple[torch.Tensor, torch.Tensor] | tuple[PIL.Image, PIL.Image]):
Either a single tensor or image or a 2-tuple of tensors or images. When tensor(s), the tensor must be of shape
``CHW`` for a single image or ``NCHW`` for a batch of images.
p_row (float, optional): Fraction of rows to drop (drop along H). Default: ``0.15``.
Expand All @@ -56,7 +56,7 @@ def colout_batch(
Otherwise, only the first object is resized. Default: ``'auto'``.
Returns:
torch.Tensor | PIL.Image | Tuple[torch.Tensor, torch.Tensor] | Tuple[PIL.Image, PIL.Image]:
torch.Tensor | PIL.Image | tuple[torch.Tensor, torch.Tensor] | tuple[PIL.Image, PIL.Image]:
A smaller image or 2-tuple of images with random rows and columns dropped.
"""

Expand Down Expand Up @@ -139,16 +139,16 @@ def __init__(self, p_row: float = 0.15, p_col: float = 0.15, resize_target: Unio

def __call__(
self,
sample: Union[ImgT, Tuple[ImgT, ImgT]],
) -> Union[torch.Tensor, ImgT, Tuple[Tensor, Tensor], Tuple[ImgT, ImgT]]:
sample: Union[ImgT, tuple[ImgT, ImgT]],
) -> Union[torch.Tensor, ImgT, tuple[Tensor, Tensor], tuple[ImgT, ImgT]]:
"""Drops random rows and columns from up to two images.
Args:
sample (torch.Tensor | PIL.Image | Tuple[torch.Tensor, torch.Tensor] | Tuple[PIL.Image, PIL.Image]):
sample (torch.Tensor | PIL.Image | tuple[torch.Tensor, torch.Tensor] | tuple[PIL.Image, PIL.Image]):
A single image or a 2-tuple of images as either :class:`torch.Tensor` or :class:`PIL.Image`.
Returns:
torch.Tensor | PIL.Image | Tuple[torch.Tensor, torch.Tensor] | Tuple[PIL.Image, PIL.Image]:
torch.Tensor | PIL.Image | tuple[torch.Tensor, torch.Tensor] | tuple[PIL.Image, PIL.Image]:
A smaller image or 2-tuple of images with random rows and columns dropped.
"""

Expand Down Expand Up @@ -193,11 +193,11 @@ class ColOut(Algorithm):
batch (bool, optional): Run ColOut at the batch level. Default: ``True``.
resize_target (bool | str, optional): Whether to resize the target in addition to the input. If set to ``'auto'``, resizing
the target will be based on whether the target has the same spatial dimensions as the input. Default: ``'auto'``.
input_key (str | int | Tuple[Callable, Callable] | Any, optional): A key that indexes to the input
input_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the input
from the batch. Can also be a pair of get and set functions, where the getter
is assumed to be first in the pair. The default is 0, which corresponds to any sequence, where the first element
is the input. Default: ``0``.
target_key (str | int | Tuple[Callable, Callable] | Any, optional): A key that indexes to the target
target_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the target
from the batch. Can also be a pair of get and set functions, where the getter
is assumed to be first in the pair. The default is 1, which corresponds to any sequence, where the second element
is the target. Default: ``1``.
Expand All @@ -209,8 +209,8 @@ def __init__(
p_col: float = 0.15,
batch: bool = True,
resize_target: Union[bool, str] = 'auto',
input_key: Union[str, int, Tuple[Callable, Callable], Any] = 0,
target_key: Union[str, int, Tuple[Callable, Callable], Any] = 1,
input_key: Union[str, int, tuple[Callable, Callable], Any] = 0,
target_key: Union[str, int, tuple[Callable, Callable], Any] = 1,
):
if not (0 <= p_col <= 1):
raise ValueError('p_col must be between 0 and 1')
Expand Down Expand Up @@ -286,7 +286,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:
self._apply_sample(state)


def _should_resize_target(sample: Union[ImgT, Tuple[ImgT, ImgT]], resize_target: Union[bool, str]) -> bool:
def _should_resize_target(sample: Union[ImgT, tuple[ImgT, ImgT]], resize_target: Union[bool, str]) -> bool:
"""Helper function to determine if both objects in the tuple should be resized.
Decision is based on ``resize_target`` and if both objects in the tuple have the same spatial size.
Expand Down
20 changes: 10 additions & 10 deletions composer/algorithms/cutmix/cutmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
Expand All @@ -26,10 +26,10 @@ def cutmix_batch(
target: Tensor,
length: Optional[float] = None,
alpha: float = 1.,
bbox: Optional[Tuple] = None,
bbox: Optional[tuple] = None,
indices: Optional[torch.Tensor] = None,
uniform_sampling: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, float, Tuple]:
) -> tuple[torch.Tensor, torch.Tensor, float, tuple]:
"""Create new samples using combinations of pairs of samples.
This is done by masking a region of each image in ``input`` and filling
Expand Down Expand Up @@ -169,11 +169,11 @@ class CutMix(Algorithm):
box such that each pixel has an equal probability of being mixed.
If ``False``, defaults to the sampling used in the original
paper implementation. Default: ``False``.
input_key (str | int | Tuple[Callable, Callable] | Any, optional): A key that indexes to the input
input_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the input
from the batch. Can also be a pair of get and set functions, where the getter
is assumed to be first in the pair. The default is 0, which corresponds to any sequence, where the first element
is the input. Default: ``0``.
target_key (str | int | Tuple[Callable, Callable] | Any, optional): A key that indexes to the target
target_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the target
from the batch. Can also be a pair of get and set functions, where the getter
is assumed to be first in the pair. The default is 1, which corresponds to any sequence, where the second element
is the target. Default: ``1``.
Expand All @@ -198,16 +198,16 @@ def __init__(
alpha: float = 1.,
interpolate_loss: bool = False,
uniform_sampling: bool = False,
input_key: Union[str, int, Tuple[Callable, Callable], Any] = 0,
target_key: Union[str, int, Tuple[Callable, Callable], Any] = 1,
input_key: Union[str, int, tuple[Callable, Callable], Any] = 0,
target_key: Union[str, int, tuple[Callable, Callable], Any] = 1,
):
self.alpha = alpha
self.interpolate_loss = interpolate_loss
self._uniform_sampling = uniform_sampling

self._indices = torch.Tensor()
self._cutmix_lambda = 0.0
self._bbox: Tuple[int, int, int, int] = (0, 0, 0, 0)
self._bbox: tuple[int, int, int, int] = (0, 0, 0, 0)
self._permuted_target = torch.Tensor()
self._adjusted_lambda = 0.0
self.input_key, self.target_key = input_key, target_key
Expand Down Expand Up @@ -339,7 +339,7 @@ def _rand_bbox(
cx: Optional[int] = None,
cy: Optional[int] = None,
uniform_sampling: bool = False,
) -> Tuple[int, int, int, int]:
) -> tuple[int, int, int, int]:
"""Randomly samples a bounding box with area determined by ``cutmix_lambda``.
Adapted from original implementation https://github.com/clovaai/CutMix-PyTorch
Expand Down Expand Up @@ -385,7 +385,7 @@ def _rand_bbox(
return bbx1, bby1, bbx2, bby2


def _adjust_lambda(x: Tensor, bbox: Tuple) -> float:
def _adjust_lambda(x: Tensor, bbox: tuple) -> float:
"""Rescale the cutmix lambda according to the size of the clipped bounding box.
Args:
Expand Down
2 changes: 1 addition & 1 deletion composer/algorithms/cutout/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class CutOut(Algorithm):
box such that each pixel has an equal probability of being masked.
If ``False``, defaults to the sampling used in the original paper
implementation. Default: ``False``.
input_key (str | int | Tuple[Callable, Callable] | Any, optional): A key that indexes to the input
input_key (str | int | tuple[Callable, Callable] | Any, optional): A key that indexes to the input
from the batch. Can also be a pair of get and set functions, where the getter
is assumed to be first in the pair. The default is 0, which corresponds to any sequence, where the first element
is the input. Default: ``0``.
Expand Down
8 changes: 4 additions & 4 deletions composer/algorithms/ema/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import contextlib
import itertools
import logging
from typing import Any, Dict, Optional, Union
from typing import Any, Optional, Union

import torch

Expand Down Expand Up @@ -304,7 +304,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:
# Swap the training model out for the ema model for checkpointing
self._ensure_ema_weights_active(state)

def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> dict[str, Any]:
state_dict = super().state_dict()
for attribute_name in self.serialized_attributes:
if attribute_name == 'ema_model':
Expand All @@ -316,7 +316,7 @@ def state_dict(self) -> Dict[str, Any]:
state_dict[attribute_name] = getattr(self, attribute_name)
return state_dict

def ensure_compatible_state_dict(self, state: Dict[str, Any]):
def ensure_compatible_state_dict(self, state: dict[str, Any]):
"""Ensure state dicts created prior to Composer 0.13.0 are compatible with later versions."""
# Version 0.13.0 and later state dicts will not include training_model.
if 'training_model' not in state:
Expand Down Expand Up @@ -351,7 +351,7 @@ def ensure_compatible_state_dict(self, state: Dict[str, Any]):
state['ema_model']['named_buffers_dict'] = named_buffers_dict
return state

def load_state_dict(self, state: Dict[str, Any], strict: bool = False):
def load_state_dict(self, state: dict[str, Any], strict: bool = False):
state_dict = self.ensure_compatible_state_dict(state)
for attribute_name, serialized_value in state_dict.items():
if attribute_name != 'repr': # skip attribute added by parent class
Expand Down
4 changes: 2 additions & 2 deletions composer/algorithms/factorize/factorize_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from typing import Optional, Tuple, Union
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -65,7 +65,7 @@ def _nmse(Y: torch.Tensor, Y_hat: torch.Tensor) -> float:
return float((diffs * diffs).mean() / Y.var())


def _svd_initialize(Wa: torch.Tensor, Wb: Optional[torch.Tensor], k: int) -> Tuple[torch.Tensor, torch.Tensor]:
def _svd_initialize(Wa: torch.Tensor, Wb: Optional[torch.Tensor], k: int) -> tuple[torch.Tensor, torch.Tensor]:
if Wb is None:
W = Wa
else:
Expand Down
Loading

0 comments on commit dfbaf14

Please sign in to comment.