inception_next dilation support, weights on hf hub, classifier reset / global pool / no head fixes
rwightman committed Aug 24, 2023
1 parent 2e03b01 commit ef2e6e6
Showing 1 changed file with 74 additions and 28 deletions.
102 changes: 74 additions & 28 deletions timm/models/inception_next.py
@@ -8,7 +8,7 @@
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
- from timm.layers import trunc_normal_, DropPath, to_2tuple
+ from timm.layers import trunc_normal_, DropPath, to_2tuple, create_conv2d, get_padding, SelectAdaptivePool2d
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs
@@ -23,16 +23,23 @@ def __init__(
in_chs,
square_kernel_size=3,
band_kernel_size=11,
- branch_ratio=0.125
+ branch_ratio=0.125,
+ dilation=1,
):
super().__init__()

gc = int(in_chs * branch_ratio) # channel numbers of a convolution branch
- self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
+ square_padding = get_padding(square_kernel_size, dilation=dilation)
+ band_padding = get_padding(band_kernel_size, dilation=dilation)
+ self.dwconv_hw = nn.Conv2d(
+ gc, gc, square_kernel_size,
+ padding=square_padding, dilation=dilation, groups=gc)
self.dwconv_w = nn.Conv2d(
- gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
+ gc, gc, (1, band_kernel_size),
+ padding=(0, band_padding), dilation=(1, dilation), groups=gc)
self.dwconv_h = nn.Conv2d(
- gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
+ gc, gc, (band_kernel_size, 1),
+ padding=(band_padding, 0), dilation=(dilation, 1), groups=gc)
self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)

def forward(self, x):
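Not part of the commit, but a quick check of what the new get_padding calls buy: at stride 1 the padding works out to ((kernel_size - 1) * dilation) // 2, so the dilated square and band depthwise convs keep their spatial size. A minimal sketch:

import torch
import torch.nn as nn
from timm.layers import get_padding

k, d = 11, 2
pad = get_padding(k, dilation=d)  # ((11 - 1) * 2) // 2 = 10
band_conv = nn.Conv2d(8, 8, (1, k), padding=(0, pad), dilation=(1, d), groups=8)  # dilated band conv
x = torch.randn(1, 8, 56, 56)
print(band_conv(x).shape)  # torch.Size([1, 8, 56, 56]) -- spatial dims preserved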
@@ -89,22 +96,25 @@ def __init__(
self,
dim,
num_classes=1000,
+ pool_type='avg',
mlp_ratio=3,
act_layer=nn.GELU,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
drop=0.,
bias=True
):
super().__init__()
- hidden_features = int(mlp_ratio * dim)
- self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
+ self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
+ in_features = dim * self.global_pool.feat_mult()
+ hidden_features = int(mlp_ratio * in_features)
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.norm = norm_layer(hidden_features)
self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
self.drop = nn.Dropout(drop)

def forward(self, x):
- x = x.mean((2, 3)) # global average pooling
+ x = self.global_pool(x)
x = self.fc1(x)
x = self.act(x)
x = self.norm(x)
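For context (an illustration, not from the diff): SelectAdaptivePool2d.feat_mult() is 1 for 'avg' or 'max' pooling and 2 for the concatenating 'catavgmax' mode, which is why in_features is scaled by it before sizing fc1:

import torch
from timm.layers import SelectAdaptivePool2d

x = torch.randn(2, 768, 7, 7)
avg_pool = SelectAdaptivePool2d(pool_type='avg', flatten=True)
cat_pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
print(avg_pool(x).shape, avg_pool.feat_mult())  # torch.Size([2, 768]) 1
print(cat_pool(x).shape, cat_pool.feat_mult())  # torch.Size([2, 1536]) 2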
@@ -124,7 +134,8 @@ class MetaNeXtBlock(nn.Module):
def __init__(
self,
dim,
- token_mixer=nn.Identity,
+ dilation=1,
+ token_mixer=InceptionDWConv2d,
norm_layer=nn.BatchNorm2d,
mlp_layer=ConvMlp,
mlp_ratio=4,
@@ -134,7 +145,7 @@ def __init__(

):
super().__init__()
- self.token_mixer = token_mixer(dim)
+ self.token_mixer = token_mixer(dim, dilation=dilation)
self.norm = norm_layer(dim)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None
@@ -156,21 +167,28 @@ def __init__(
self,
in_chs,
out_chs,
- ds_stride=2,
+ stride=2,
depth=2,
+ dilation=(1, 1),
drop_path_rates=None,
ls_init_value=1.0,
- token_mixer=nn.Identity,
+ token_mixer=InceptionDWConv2d,
act_layer=nn.GELU,
norm_layer=None,
mlp_ratio=4,
):
super().__init__()
self.grad_checkpointing = False
- if ds_stride > 1:
+ if stride > 1 or dilation[0] != dilation[1]:
self.downsample = nn.Sequential(
norm_layer(in_chs),
- nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride),
+ nn.Conv2d(
+ in_chs,
+ out_chs,
+ kernel_size=2,
+ stride=stride,
+ dilation=dilation[0],
+ ),
)
else:
self.downsample = nn.Identity()
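As an aside (not from the commit): the widened condition means a projection layer is also built when only the dilation changes between stages (stride stays 1), and the kernel size is now a fixed 2 rather than following the old ds_stride. A rough shape check with made-up channel counts:

import torch
import torch.nn as nn

x = torch.randn(1, 96, 56, 56)
down = nn.Conv2d(96, 192, kernel_size=2, stride=2, dilation=1)  # usual stride-2 downsample
print(down(x).shape)  # torch.Size([1, 192, 28, 28])
down = nn.Conv2d(96, 192, kernel_size=2, stride=1, dilation=2)  # dilated variant once output_stride is reached
print(down(x).shape)  # torch.Size([1, 192, 54, 54]) -- the unpadded 2x2 kernel only trims the edges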
@@ -180,6 +198,7 @@ def __init__(
for i in range(depth):
stage_blocks.append(MetaNeXtBlock(
dim=out_chs,
+ dilation=dilation[1],
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
token_mixer=token_mixer,
@@ -221,10 +240,11 @@ def __init__(
self,
in_chans=3,
num_classes=1000,
+ global_pool='avg',
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
- token_mixers=nn.Identity,
+ token_mixers=InceptionDWConv2d,
norm_layer=nn.BatchNorm2d,
act_layer=nn.GELU,
mlp_ratios=(4, 4, 4, 3),
@@ -241,6 +261,7 @@ def __init__(
if not isinstance(mlp_ratios, (list, tuple)):
mlp_ratios = [mlp_ratios] * num_stage
self.num_classes = num_classes
+ self.global_pool = global_pool
self.drop_rate = drop_rate
self.feature_info = []

@@ -266,7 +287,8 @@ def __init__(
self.stages.append(MetaNeXtStage(
prev_chs,
out_chs,
- ds_stride=2 if i > 0 else 1,
+ stride=stride if i > 0 else 1,
+ dilation=(first_dilation, dilation),
depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
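The lines computing stride, first_dilation, and dilation per stage are collapsed in this view. Purely as a hypothetical illustration of how timm models typically map an output_stride limit onto per-stage (stride, dilation) pairs (variable names assumed, not the commit's actual code):

output_stride = 16            # requested maximum total reduction
curr_stride = 4               # stem has already downsampled 4x
dilation = 1
stage_args = []
for i in range(4):
    stride = 2 if i > 0 else 1            # first stage keeps resolution
    if curr_stride >= output_stride and stride > 1:
        dilation *= stride                # swap further striding for dilation once the limit is hit
        stride = 1
    first_dilation = 1 if dilation in (1, 2) else 2
    curr_stride *= stride
    stage_args.append((stride, (first_dilation, dilation)))
print(stage_args)  # [(1, (1, 1)), (2, (1, 1)), (2, (1, 1)), (1, (1, 2))] for output_stride=16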
@@ -278,7 +300,15 @@ def __init__(
prev_chs = out_chs
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.num_features = prev_chs
- self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
+ if self.num_classes > 0:
+ if issubclass(head_fn, MlpClassifierHead):
+ assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+ self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate)
+ else:
+ if self.global_pool:
+ self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+ else:
+ self.head = nn.Identity()
self.apply(self._init_weights)

def _init_weights(self, m):
@@ -301,9 +331,18 @@ def group_matcher(self, coarse=False):
def get_classifier(self):
return self.head.fc2

- def reset_classifier(self, num_classes=0, global_pool=None):
- # FIXME
- self.head.reset(num_classes, global_pool)
+ def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead):
+ if global_pool is not None:
+ self.global_pool = global_pool
+ if num_classes > 0:
+ if issubclass(head_fn, MlpClassifierHead):
+ assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+ self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate)
+ else:
+ if self.global_pool:
+ self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+ else:
+ self.head = nn.Identity()

@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
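An illustrative usage sketch (not from the commit) of the reworked reset_classifier, which now handles the no-classifier and no-pooling cases that were previously a FIXME:

import timm

model = timm.create_model('inception_next_tiny')
model.reset_classifier(num_classes=10)       # fresh MlpClassifierHead with 10 classes
model.reset_classifier(num_classes=0)        # drop the classifier; head reduces to global pooling
model.reset_classifier(0, global_pool='')    # no pooling either; head becomes nn.Identity (NCHW features out)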
@@ -319,9 +358,12 @@ def forward_features(self, x):
x = self.stages(x)
return x

- def forward_head(self, x):
- x = self.head(x)
- return x
+ def forward_head(self, x, pre_logits: bool = False):
+ if pre_logits:
+ if hasattr(self.head, 'global_pool'):
+ x = self.head.global_pool(x)
+ return x
+ return self.head(x)

def forward(self, x):
x = self.forward_features(x)
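For reference, a small sketch (assuming the tiny variant at 224x224; shapes are illustrative) of the new pre_logits path, which pools through the head's global_pool but skips the MLP:

import torch
import timm

model = timm.create_model('inception_next_tiny')
x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)                    # NCHW feature map, e.g. (1, 768, 7, 7)
pooled = model.forward_head(feats, pre_logits=True)  # pooled features, e.g. (1, 768); fc1/fc2 skipped
logits = model.forward_head(feats)                   # class logits, (1, 1000)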
@@ -342,18 +384,22 @@ def _cfg(url='', **kwargs):

default_cfgs = generate_default_cfgs({
'inception_next_tiny.sail_in1k': _cfg(
- url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
+ hf_hub_id='timm/',
+ # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
),
'inception_next_small.sail_in1k': _cfg(
- url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
+ hf_hub_id='timm/',
+ # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
),
'inception_next_base.sail_in1k': _cfg(
- url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
+ hf_hub_id='timm/',
+ # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
crop_pct=0.95,
),
'inception_next_base.sail_in1k_384': _cfg(
- url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
- input_size=(3, 384, 384), crop_pct=1.0,
+ hf_hub_id='timm/',
+ # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
+ input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0,
),
})
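With hf_hub_id set, pretrained weights now resolve from the Hugging Face Hub rather than the commented-out GitHub release URLs, and the dilation plumbing above allows reduced output strides. A short usage sketch:

import timm

model = timm.create_model('inception_next_tiny.sail_in1k', pretrained=True)  # weights pulled from the HF hub
dilated = timm.create_model('inception_next_tiny', output_stride=16)         # stride-16 features via dilation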
