Skip to content

Commit

Permalink
include suggestions from review
Browse files Browse the repository at this point in the history
Co-Authored-By: Ross Wightman <[email protected]>
  • Loading branch information
a-r-r-o-w and rwightman committed Oct 30, 2023
1 parent 5f14bdd commit d5f1525
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
8 changes: 3 additions & 5 deletions timm/layers/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import functools
import types
from typing import Tuple, Union
from typing import Callable, Tuple, Type, Union

import torch.nn
import torch


LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
LayerType = Union[str, Callable, Type[torch.nn.Module]]
PadType = Union[str, int, Tuple[int, int]]
9 changes: 4 additions & 5 deletions timm/models/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
Expand Down Expand Up @@ -151,7 +150,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x: Tensor) -> Tensor:
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_stem(x)
x = self.bn1(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
Expand All @@ -160,7 +159,7 @@ def forward_features(self, x: Tensor) -> Tensor:
x = self.blocks(x)
return x

def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
Expand All @@ -171,7 +170,7 @@ def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x)

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
Expand Down Expand Up @@ -262,7 +261,7 @@ def __init__(
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable

def forward(self, x: Tensor) -> List[Tensor]:
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
Expand Down
11 changes: 5 additions & 6 deletions timm/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
Expand Down Expand Up @@ -112,7 +111,7 @@ def zero_init_last(self):
if getattr(self.bn2, 'weight', None) is not None:
nn.init.zeros_(self.bn2.weight)

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x

x = self.conv1(x)
Expand Down Expand Up @@ -212,7 +211,7 @@ def zero_init_last(self):
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x

x = self.conv1(x)
Expand Down Expand Up @@ -554,7 +553,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)

def forward_features(self, x: Tensor) -> Tensor:
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.act1(x)
Expand All @@ -569,13 +568,13 @@ def forward_features(self, x: Tensor) -> Tensor:
x = self.layer4(x)
return x

def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
return x if pre_logits else self.fc(x)

def forward(self, x: Tensor) -> Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.forward_features(x)
x = self.forward_head(x)
return x
Expand Down

0 comments on commit d5f1525

Please sign in to comment.