diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 8f56b5f156..f308a580b8 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -40,6 +40,7 @@ from .pnasnet import * from .pvt_v2 import * from .regnet import * +from .repvit import * from .res2net import * from .resnest import * from .resnet import * diff --git a/timm/models/repvit.py b/timm/models/repvit.py new file mode 100644 index 0000000000..e5e32880a6 --- /dev/null +++ b/timm/models/repvit.py @@ -0,0 +1,404 @@ +""" RepViT + +Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective` + - https://arxiv.org/abs/2307.09283 + +@misc{wang2023repvit, + title={RepViT: Revisiting Mobile CNN From ViT Perspective}, + author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding}, + year={2023}, + eprint={2307.09283}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} + +Adapted from official impl at https://github.com/jameslahm/RepViT +""" + +__all__ = ['RepViT'] + +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from ._registry import register_model, generate_default_cfgs +from ._builder import build_model_with_cfg +from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple +from ._manipulate import checkpoint_seq + +import torch + + +class ConvNorm(nn.Sequential): + def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False)) + self.add_module('bn', nn.BatchNorm2d(out_dim)) + nn.init.constant_(self.bn.weight, bn_weight_init) + nn.init.constant_(self.bn.bias, 0) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Conv2d( + w.size(1) * self.c.groups, + w.size(0), + w.shape[2:], + stride=self.c.stride, + padding=self.c.padding, + dilation=self.c.dilation, + groups=self.c.groups, + device=c.weight.device, + ) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class NormLinear(nn.Sequential): + def __init__(self, in_dim, out_dim, bias=True, std=0.02): + super().__init__() + self.add_module('bn', nn.BatchNorm1d(in_dim)) + self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias)) + trunc_normal_(self.l.weight, std=std) + if bias: + nn.init.constant_(self.l.bias, 0) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = nn.Linear(w.size(1), w.size(0), device=l.weight.device) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class RepVGGDW(nn.Module): + def __init__(self, ed, kernel_size): + super().__init__() + self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed) + self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed) + self.dim = ed + + def forward(self, x): + return self.conv(x) + self.conv1(x) + x + + @torch.no_grad() + def fuse(self): + conv = self.conv.fuse() + conv1 = self.conv1.fuse() + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1]) + + identity = nn.functional.pad( + torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1, 1, 1, 1] + ) + + final_conv_w = conv_w + conv1_w + identity + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + return conv + + +class RepViTMlp(nn.Module): + def __init__(self, in_dim, hidden_dim, act_layer): + super().__init__() + self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0) + self.act = act_layer() + self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0) + + def forward(self, x): + return self.conv2(self.act(self.conv1(x))) + + +class RepViTBlock(nn.Module): + def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer): + super(RepViTBlock, self).__init__() + + self.token_mixer = RepVGGDW(in_dim, kernel_size) + self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity() + self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer) + + def forward(self, x): + x = self.token_mixer(x) + x = self.se(x) + identity = x + x = self.channel_mixer(x) + return identity + x + + +class RepViTStem(nn.Module): + def __init__(self, in_chs, out_chs, act_layer): + super().__init__() + self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1) + self.act1 = act_layer() + self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1) + self.stride = 4 + + def forward(self, x): + return self.conv2(self.act1(self.conv1(x))) + + +class RepViTDownsample(nn.Module): + def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer): + super().__init__() + self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer) + self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim) + self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1) + self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer) + + def forward(self, x): + x = self.pre_block(x) + x = self.spatial_downsample(x) + x = self.channel_downsample(x) + identity = x + x = self.ffn(x) + return x + identity + + +class RepViTClassifier(nn.Module): + def __init__(self, dim, num_classes, distillation=False): + super().__init__() + self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.distillation = distillation + if distillation: + self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.num_classes = num_classes + + def forward(self, x): + if self.distillation: + x1, x2 = self.head(x), self.head_dist(x) + if (not self.training) or torch.jit.is_scripting(): + return (x1 + x2) / 2 + else: + return x1, x2 + else: + x = self.head(x) + return x + + @torch.no_grad() + def fuse(self): + if not self.num_classes > 0: + return nn.Identity() + head = self.head.fuse() + if self.distillation: + head_dist = self.head_dist.fuse() + head.weight += head_dist.weight + head.bias += head_dist.bias + head.weight /= 2 + head.bias /= 2 + return head + else: + return head + + +class RepViTStage(nn.Module): + def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True): + super().__init__() + if downsample: + self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer) + else: + assert in_dim == out_dim + self.downsample = nn.Identity() + + blocks = [] + use_se = True + for _ in range(depth): + blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer)) + use_se = not use_se + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class RepViT(nn.Module): + def __init__( + self, + in_chans=3, + img_size=224, + embed_dim=(48,), + depth=(2,), + mlp_ratio=2, + global_pool='avg', + kernel_size=3, + num_classes=1000, + act_layer=nn.GELU, + distillation=True, + ): + super(RepViT, self).__init__() + self.grad_checkpointing = False + self.global_pool = global_pool + self.embed_dim = embed_dim + + in_dim = embed_dim[0] + self.stem = RepViTStem(in_chans, in_dim, act_layer) + stride = self.stem.stride + resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))]) + + num_stages = len(embed_dim) + mlp_ratios = to_ntuple(num_stages)(mlp_ratio) + + self.feature_info = [] + stages = [] + for i in range(num_stages): + downsample = True if i != 0 else False + stages.append( + RepViTStage( + in_dim, + embed_dim[i], + depth[i], + mlp_ratio=mlp_ratios[i], + act_layer=act_layer, + kernel_size=kernel_size, + downsample=downsample, + ) + ) + stage_stride = 2 if downsample else 1 + stride *= stage_stride + resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution]) + self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')] + in_dim = embed_dim[i] + self.stages = nn.Sequential(*stages) + + self.num_features = embed_dim[-1] + self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None, distillation=False): + self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool + self.head = ( + RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity() + ) + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + @torch.no_grad() + def fuse(self): + def fuse_children(net): + for child_name, child in net.named_children(): + if hasattr(child, 'fuse'): + fused = child.fuse() + setattr(net, child_name, fused) + fuse_children(fused) + else: + fuse_children(child) + + fuse_children(self) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': (7, 7), + 'crop_pct': 0.95, + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.c', + 'classifier': ('head.head.l', 'head.head_dist.l'), + **kwargs, + } + + +default_cfgs = generate_default_cfgs( + { + 'repvit_m1': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth' + ), + 'repvit_m2': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth' + ), + 'repvit_m3': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth' + ), + } +) + + +def _create_repvit(variant, pretrained=False, **kwargs): + out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) + model = build_model_with_cfg( + RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs + ) + return model + + +@register_model +def repvit_m1(pretrained=False, **kwargs): + """ + Constructs a RepViT-M1 model + """ + model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2)) + return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m2(pretrained=False, **kwargs): + """ + Constructs a RepViT-M2 model + """ + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2)) + return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m3(pretrained=False, **kwargs): + """ + Constructs a RepViT-M3 model + """ + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2)) + return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))