Add type annotation when __init__ doesn't take any argument (mosaicml#3347)

According to the [mypy documentation](https://mypy.readthedocs.io/en/stable/class_basics.html#annotating-init-methods):
> if __init__ has no annotated arguments and no return type annotation, it is considered an untyped method
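
For illustration, here is a minimal sketch of the behavior described above (class names are hypothetical; assumes mypy's default settings, i.e. without `--check-untyped-defs`). Mypy skips the body of a fully unannotated `__init__`, while the identical body with a `-> None` return annotation is type-checked:

```python
class Unchecked:

    def __init__(self):  # no annotated arguments, no return type: untyped method
        self.count = 'not a number'
        self.count += 1  # bug, but mypy skips this body by default


class Checked:

    def __init__(self) -> None:  # the return annotation opts the body into checking
        self.count = 'not a number'
        self.count += 1  # mypy: Unsupported operand types for + ("str" and "int")
```

Surfacing the error in `Checked` is the only effect of the otherwise no-op `-> None` annotations added in this commit.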

Co-authored-by: Alex Ghelfi <[email protected]>
antoinebrl and Ghelfi committed May 31, 2024
1 parent 3c0a817 commit 3241a85
Showing 13 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion composer/algorithms/channels_last/channels_last.py
@@ -52,7 +52,7 @@ class ChannelsLast(Algorithm):
             )
     """

-    def __init__(self):
+    def __init__(self) -> None:
         # ChannelsLast takes no arguments
         pass

2 changes: 1 addition & 1 deletion composer/algorithms/no_op_model/no_op_model.py
@@ -73,7 +73,7 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
 class NoOpModel(Algorithm):
     """Runs on :attr:`Event.INIT` and replaces the model with a dummy :class:`.NoOpModelClass` instance."""

-    def __init__(self):
+    def __init__(self) -> None:
         # No arguments
         pass

2 changes: 1 addition & 1 deletion composer/devices/device_mps.py
@@ -27,7 +27,7 @@ class DeviceMPS(Device):
     dist_backend = ''
     name = 'mps'

-    def __init__(self):
+    def __init__(self) -> None:
         if version.parse(torch.__version__) < version.parse('1.12.0'):
             raise RuntimeError('Support for MPS device requires torch >= 1.12.')
         if not torch.backends.mps.is_available():  # type: ignore (version guarded)
2 changes: 1 addition & 1 deletion composer/devices/device_neuron.py
@@ -29,7 +29,7 @@ class DeviceNeuron(Device):
     name = 'neuron'
     dist_backend = 'xla'

-    def __init__(self):
+    def __init__(self) -> None:
         import torch_xla.core.xla_model as xm

         # Turn off compiler based mixed precision (we use torch's amp support)
2 changes: 1 addition & 1 deletion composer/devices/device_tpu.py
@@ -29,7 +29,7 @@ class DeviceTPU(Device):
     dist_backend = 'xla'
     name = 'tpu'

-    def __init__(self):
+    def __init__(self) -> None:
         import torch_xla.core.xla_model as xm

         self._device = xm.xla_device()
2 changes: 1 addition & 1 deletion composer/utils/dist.py
@@ -500,7 +500,7 @@ def is_initialized():
     return dist.is_initialized()


-def initialize_dist(device: Union[str, Device], timeout: float = 300.0):
+def initialize_dist(device: Union[str, Device], timeout: float = 300.0) -> None:
     """Initialize the default PyTorch distributed process group.

     This function assumes that the following environment variables are set:
2 changes: 1 addition & 1 deletion tests/common/models.py
@@ -405,7 +405,7 @@ def __init__(self, vocab_size: int = 10, num_classes: int = 2):
 class ConvModel(ComposerClassifier):
     """Convolutional network featuring strided convs, a batchnorm, max pooling, and average pooling."""

-    def __init__(self):
+    def __init__(self) -> None:
         conv_args = {'kernel_size': (3, 3), 'padding': 1}
         conv1 = torch.nn.Conv2d(in_channels=32, out_channels=8, stride=2, bias=False, **conv_args)  # stride > 1
         conv2 = torch.nn.Conv2d(
2 changes: 1 addition & 1 deletion tests/trainer/test_ddp_sync_strategy.py
@@ -18,7 +18,7 @@

 class MinimalConditionalModel(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

         self.choice1 = nn.Linear(1, 1, bias=False)
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer_eval.py
@@ -695,7 +695,7 @@ def test_infinite_eval_dataloader(eval_subset_num_batches, success):

 class BreakBatchAlgorithm(Algorithm):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     def match(self, event, state):
4 changes: 2 additions & 2 deletions tests/utils/test_autolog_hparams.py
@@ -39,12 +39,12 @@ def test_extract_hparams():

     class Foo:

-        def __init__(self):
+        def __init__(self) -> None:
             self.g = 7

     class Bar:

-        def __init__(self):
+        def __init__(self) -> None:
             self.local_hparams = {'m': 11}

     class Baz(StringEnum):
10 changes: 5 additions & 5 deletions tests/utils/test_fx_utils.py
@@ -15,7 +15,7 @@

 class MyTestModel(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.relu = nn.ReLU()
         self.factor = 0.5
@@ -68,7 +68,7 @@ def test_replace_op(model_cls, src_ops, tgt_op, count):

 class SimpleParallelLinears(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(64, 64)
         self.fc2 = nn.Linear(64, 64)
@@ -81,7 +81,7 @@ def forward(self, x):

 class ParallelLinears(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(64, 64)
         self.ln = nn.LayerNorm(64)
@@ -98,7 +98,7 @@ def forward(self, x):

 class NotFusibleLinears(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(64, 64, bias=False)
         self.ln = nn.LayerNorm(64)
@@ -115,7 +115,7 @@ def forward(self, x):

 class NotParallelLinears(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(64, 64)
         self.ln = nn.LayerNorm(64)
2 changes: 1 addition & 1 deletion tests/utils/test_inference.py
@@ -500,7 +500,7 @@ def test_export_with_other_logger(model_cls, dataloader):

 class LinModel(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.lin1 = nn.Linear(256, 128)
         self.lin2 = nn.Linear(128, 256)
4 changes: 2 additions & 2 deletions tests/utils/test_module_surgery.py
@@ -26,7 +26,7 @@ def __init__(self, in_features: int, out_features: int):
 class SimpleReplacementPolicy(nn.Module):
     """Bundle the model, replacement function, and validation into one class."""

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         self.fc1 = nn.Linear(in_features=16, out_features=32)
         self.fc2 = nn.Linear(in_features=32, out_features=10)
@@ -208,7 +208,7 @@ def test_params_kept(optimizer_surgery_state):

 class ParamTestModel(nn.Module):

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

         self.fc1 = nn.Linear(8, 8)
