From ef2e6e6580fc65f19ac799292ce52872b10953a9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 24 Aug 2023 12:30:34 -0700
Subject: [PATCH] inception_next dilation support, weights on hf hub,
 classifier reset / global pool / no head fixes

---
 timm/models/inception_next.py | 102 ++++++++++++++++++++++++----------
 1 file changed, 74 insertions(+), 28 deletions(-)

diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py
index df6bb4483e..da3a582bdf 100644
--- a/timm/models/inception_next.py
+++ b/timm/models/inception_next.py
@@ -8,7 +8,7 @@
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import trunc_normal_, DropPath, to_2tuple
+from timm.layers import trunc_normal_, DropPath, to_2tuple, create_conv2d, get_padding, SelectAdaptivePool2d
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -23,16 +23,23 @@ def __init__(
             self,
             in_chs,
             square_kernel_size=3,
             band_kernel_size=11,
-            branch_ratio=0.125
+            branch_ratio=0.125,
+            dilation=1,
     ):
         super().__init__()
         gc = int(in_chs * branch_ratio)  # channel numbers of a convolution branch
-        self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size // 2, groups=gc)
+        square_padding = get_padding(square_kernel_size, dilation=dilation)
+        band_padding = get_padding(band_kernel_size, dilation=dilation)
+        self.dwconv_hw = nn.Conv2d(
+            gc, gc, square_kernel_size,
+            padding=square_padding, dilation=dilation, groups=gc)
         self.dwconv_w = nn.Conv2d(
-            gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size // 2), groups=gc)
+            gc, gc, (1, band_kernel_size),
+            padding=(0, band_padding), dilation=(1, dilation), groups=gc)
         self.dwconv_h = nn.Conv2d(
-            gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size // 2, 0), groups=gc)
+            gc, gc, (band_kernel_size, 1),
+            padding=(band_padding, 0), dilation=(dilation, 1), groups=gc)
         self.split_indexes = (in_chs - 3 * gc, gc, gc, gc)
 
     def forward(self, x):
@@ -89,6 +96,7 @@ def __init__(
             self,
             dim,
             num_classes=1000,
+            pool_type='avg',
             mlp_ratio=3,
             act_layer=nn.GELU,
             norm_layer=partial(nn.LayerNorm, eps=1e-6),
@@ -96,15 +104,17 @@ def __init__(
             bias=True
     ):
         super().__init__()
-        hidden_features = int(mlp_ratio * dim)
-        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
+        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
+        in_features = dim * self.global_pool.feat_mult()
+        hidden_features = int(mlp_ratio * in_features)
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
         self.act = act_layer()
         self.norm = norm_layer(hidden_features)
         self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
-        x = x.mean((2, 3))  # global average pooling
+        x = self.global_pool(x)
         x = self.fc1(x)
         x = self.act(x)
         x = self.norm(x)
@@ -124,7 +134,8 @@ class MetaNeXtBlock(nn.Module):
     def __init__(
             self,
             dim,
-            token_mixer=nn.Identity,
+            dilation=1,
+            token_mixer=InceptionDWConv2d,
             norm_layer=nn.BatchNorm2d,
             mlp_layer=ConvMlp,
             mlp_ratio=4,
@@ -134,7 +145,7 @@ def __init__(
 
     ):
         super().__init__()
-        self.token_mixer = token_mixer(dim)
+        self.token_mixer = token_mixer(dim, dilation=dilation)
         self.norm = norm_layer(dim)
         self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer)
         self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None
@@ -156,21 +167,28 @@ def __init__(
             self,
             in_chs,
             out_chs,
-            ds_stride=2,
+            stride=2,
             depth=2,
+            dilation=(1, 1),
             drop_path_rates=None,
             ls_init_value=1.0,
-            token_mixer=nn.Identity,
+            token_mixer=InceptionDWConv2d,
             act_layer=nn.GELU,
             norm_layer=None,
             mlp_ratio=4,
     ):
         super().__init__()
         self.grad_checkpointing = False
-        if ds_stride > 1:
+        if stride > 1 or dilation[0] != dilation[1]:
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride),
+                nn.Conv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=2,
+                    stride=stride,
+                    dilation=dilation[0],
+                ),
             )
         else:
             self.downsample = nn.Identity()
@@ -180,6 +198,7 @@ def __init__(
         for i in range(depth):
             stage_blocks.append(MetaNeXtBlock(
                 dim=out_chs,
+                dilation=dilation[1],
                 drop_path=drop_path_rates[i],
                 ls_init_value=ls_init_value,
                 token_mixer=token_mixer,
@@ -221,10 +240,11 @@ def __init__(
             self,
             in_chans=3,
             num_classes=1000,
+            global_pool='avg',
             output_stride=32,
             depths=(3, 3, 9, 3),
             dims=(96, 192, 384, 768),
-            token_mixers=nn.Identity,
+            token_mixers=InceptionDWConv2d,
             norm_layer=nn.BatchNorm2d,
             act_layer=nn.GELU,
             mlp_ratios=(4, 4, 4, 3),
@@ -241,6 +261,7 @@ def __init__(
         if not isinstance(mlp_ratios, (list, tuple)):
             mlp_ratios = [mlp_ratios] * num_stage
         self.num_classes = num_classes
+        self.global_pool = global_pool
         self.drop_rate = drop_rate
         self.feature_info = []
 
@@ -266,7 +287,8 @@ def __init__(
             self.stages.append(MetaNeXtStage(
                 prev_chs,
                 out_chs,
-                ds_stride=2 if i > 0 else 1,
+                stride=stride if i > 0 else 1,
+                dilation=(first_dilation, dilation),
                 depth=depths[i],
                 drop_path_rates=dp_rates[i],
                 ls_init_value=ls_init_value,
@@ -278,7 +300,15 @@ def __init__(
             prev_chs = out_chs
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
         self.num_features = prev_chs
-        self.head = head_fn(self.num_features, num_classes, drop=drop_rate)
+        if self.num_classes > 0:
+            if issubclass(head_fn, MlpClassifierHead):
+                assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+            self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate)
+        else:
+            if self.global_pool:
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
@@ -301,9 +331,18 @@ def group_matcher(self, coarse=False):
     def get_classifier(self):
         return self.head.fc2
 
-    def reset_classifier(self, num_classes=0, global_pool=None):
-        # FIXME
-        self.head.reset(num_classes, global_pool)
+    def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead):
+        if global_pool is not None:
+            self.global_pool = global_pool
+        if num_classes > 0:
+            if issubclass(head_fn, MlpClassifierHead):
+                assert self.global_pool, 'Cannot disable global pooling with MLP head present.'
+            self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate)
+        else:
+            if self.global_pool:
+                self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True)
+            else:
+                self.head = nn.Identity()
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
@@ -319,9 +358,12 @@ def forward_features(self, x):
         x = self.stages(x)
         return x
 
-    def forward_head(self, x):
-        x = self.head(x)
-        return x
+    def forward_head(self, x, pre_logits: bool = False):
+        if pre_logits:
+            if hasattr(self.head, 'global_pool'):
+                x = self.head.global_pool(x)
+            return x
+        return self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -342,18 +384,22 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
     'inception_next_tiny.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth',
    ),
     'inception_next_small.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth',
     ),
     'inception_next_base.sail_in1k': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth',
         crop_pct=0.95,
     ),
     'inception_next_base.sail_in1k_384': _cfg(
-        url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
-        input_size=(3, 384, 384), crop_pct=1.0,
+        hf_hub_id='timm/',
+        # url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth',
+        input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0,
    ),
 })
 
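
Below is a brief usage sketch (not part of the patch) illustrating the behaviour the diff adds. It assumes the change above is applied on top of a recent timm checkout; the model names, keyword arguments, and head behaviour are taken from the code in this diff.

# Usage sketch only -- assumes this patch is applied to a recent timm install.
import torch
import timm

x = torch.randn(1, 3, 224, 224)

# Default cfgs now resolve weights via hf_hub_id='timm/'; pretrained=False keeps this offline.
model = timm.create_model('inception_next_tiny', pretrained=False)

# pre_logits uses the MLP head's internal SelectAdaptivePool2d and skips the final fc2.
pooled = model.forward_head(model.forward_features(x), pre_logits=True)

# reset_classifier rebuilds the MLP head for a new class count,
# or swaps in global pooling / identity when num_classes=0.
model.reset_classifier(num_classes=10)
model.reset_classifier(num_classes=0)

# output_stride < 32 now converts later downsampling to dilated convs,
# so forward_features keeps a larger spatial resolution.
dilated = timm.create_model('inception_next_tiny', pretrained=False, output_stride=16)
print(dilated.forward_features(x).shape)  # reduction 16 instead of 32

Note that the pre_logits path relies on the head exposing a global_pool attribute, which is why the MLP head gains its own SelectAdaptivePool2d in this change.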