diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index 7b6e2ce85b..c8a6007dd0 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -7,8 +7,10 @@
 """
 
 __all__ = ['TinyVit']
+
 import math
 import itertools
+from functools import partial
 from typing import Dict
 
 import torch
@@ -16,7 +18,8 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, to_2tuple, trunc_normal_, resample_relative_position_bias_table, _assert
+from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
+    to_2tuple, trunc_normal_, resample_relative_position_bias_table, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
@@ -33,10 +36,10 @@ def __init__(self, in_chs, out_chs, ks=1, stride=1, pad=0, dilation=1, groups=1,
     @torch.no_grad()
     def fuse(self):
         c, bn = self.conv, self.bn
-        w = bn.weight / (bn.running_var + bn.eps)**0.5
+        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
         w = c.weight * w[:, None, None, None]
         b = bn.bias - bn.running_mean * bn.weight / \
-            (bn.running_var + bn.eps)**0.5
+            (bn.running_var + bn.eps) ** 0.5
         m = torch.nn.Conv2d(
             w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
             stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
@@ -46,18 +49,12 @@ def fuse(self):
 
 
 class PatchEmbed(nn.Module):
-    def __init__(self, in_chans, embed_dim, resolution, activation):
+    def __init__(self, in_chs, out_chs, act_layer):
         super().__init__()
-        img_size = to_2tuple(resolution)
-        self.patches_resolution = (math.ceil(img_size[0] / 4), math.ceil(img_size[1] / 4))
-        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
         self.stride = 4
-        n = embed_dim
-        self.conv1 = ConvNorm(self.in_chans, n // 2, 3, 2, 1)
-        self.act = activation()
-        self.conv2 = ConvNorm(n // 2, n, 3, 2, 1)
+        self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
+        self.act = act_layer()
+        self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1)
 
     def forward(self, x):
         x = self.conv1(x)
@@ -67,17 +64,15 @@ def forward(self, x):
 
 
 class MBConv(nn.Module):
-    def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
+    def __init__(self, in_chs, out_chs, expand_ratio, act_layer, drop_path):
         super().__init__()
-        self.in_chans = in_chans
-        self.hidden_chans = int(in_chans * expand_ratio)
-        self.out_chans = out_chans
-        self.conv1 = ConvNorm(in_chans, self.hidden_chans, ks=1)
-        self.act1 = activation()
-        self.conv2 = ConvNorm(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
-        self.act2 = activation()
-        self.conv3 = ConvNorm(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
-        self.act3 = activation()
+        mid_chs = int(in_chs * expand_ratio)
+        self.conv1 = ConvNorm(in_chs, mid_chs, ks=1)
+        self.act1 = act_layer()
+        self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs)
+        self.act2 = act_layer()
+        self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0)
+        self.act3 = act_layer()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
@@ -94,97 +89,90 @@ def forward(self, x):
 
 
 class PatchMerging(nn.Module):
-    def __init__(self, input_resolution, dim, out_dim, activation, in_fmt='BCHW'):
+    def __init__(self, dim, out_dim, act_layer):
         super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim
-        self.out_dim = out_dim
-        self.act = activation()
         self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0)
+        self.act1 = act_layer()
         self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim)
+        self.act2 = act_layer()
         self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0)
-        self.output_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
-        self.in_fmt = in_fmt
-        assert self.in_fmt in ['BCHW', 'BLC']
 
     def forward(self, x):
-        if self.in_fmt == 'BLC':
-            # (B, H * W, C) -> (B, C, H, W)
-            H, W = self.input_resolution
-            B = x.shape[0]
-            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
         x = self.conv1(x)
-        x = self.act(x)
+        x = self.act1(x)
         x = self.conv2(x)
-        x = self.act(x)
+        x = self.act2(x)
         x = self.conv3(x)
-        # (B, C, H, W) -> (B, H * W, C)
-        x = x.flatten(2).transpose(1, 2)
         return x
 
 
 class ConvLayer(nn.Module):
-    def __init__(self, dim, input_resolution, depth, activation, drop_path=0.,
-                 downsample=None, conv_expand_ratio=4.):
+    def __init__(
+            self,
+            dim,
+            depth,
+            act_layer,
+            drop_path=0.,
+            conv_expand_ratio=4.,
+    ):
         super().__init__()
         self.dim = dim
-        self.input_resolution = input_resolution
         self.depth = depth
-        # build blocks
         self.blocks = nn.Sequential(*[
-            MBConv(dim, dim, conv_expand_ratio, activation,
-                   drop_path[i] if isinstance(drop_path, list) else drop_path,
-                   )
-            for i in range(depth)])
+            MBConv(
+                dim, dim, conv_expand_ratio, act_layer,
+                drop_path[i] if isinstance(drop_path, list) else drop_path,
+            )
+            for i in range(depth)
+        ])
 
     def forward(self, x):
         x = self.blocks(x)
         return x
 
 
-class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None,
-                 out_features=None, act_layer=nn.GELU, drop=0.):
+class NormMlp(nn.Module):
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            norm_layer=nn.LayerNorm,
+            act_layer=nn.GELU,
+            drop=0.,
+    ):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.norm = nn.LayerNorm(in_features)
+        self.norm = norm_layer(in_features)
         self.fc1 = nn.Linear(in_features, hidden_features)
-        self.fc2 = nn.Linear(hidden_features, out_features)
         self.act = act_layer()
-        self.drop = nn.Dropout(drop)
+        self.drop1 = nn.Dropout(drop)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop2 = nn.Dropout(drop)
 
     def forward(self, x):
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-class ClassifierHead(nn.Module):
-    def __init__(
-            self,
-            in_channels,
-            num_classes=1000
-    ):
-        super(ClassifierHead, self).__init__()
-        self.norm_head = nn.LayerNorm(in_channels)
-        self.fc = nn.Linear(in_channels, num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward(self, x):
-        x = x.mean(1)
-        x = self.norm_head(x)
-        x = self.fc(x)
+        x = self.drop2(x)
         return x
 
 
 class Attention(torch.nn.Module):
+    fused_attn: torch.jit.Final[bool]
     attention_bias_cache: Dict[str, torch.Tensor]
 
-    def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14)):
+    def __init__(
+            self,
+            dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=4,
+            resolution=(14, 14),
+    ):
         super().__init__()
         assert isinstance(resolution, tuple) and len(resolution) == 2
         self.num_heads = num_heads
@@ -194,6 +182,8 @@ def __init__(self, dim, key_dim, num_heads=8, attn_ratio=4, resolution=(14, 14))
         self.d = int(attn_ratio * key_dim)
         self.dh = int(attn_ratio * key_dim) * num_heads
         self.attn_ratio = attn_ratio
+        self.fused_attn = use_fused_attn()
+
         h = self.dh + nh_kd * 2
 
         self.norm = nn.LayerNorm(dim)
@@ -242,12 +232,15 @@ def forward(self, x):
         k = k.permute(0, 2, 1, 3)
         v = v.permute(0, 2, 1, 3)
 
-        q = q * self.scale
-        attn = q @ k.transpose(-2, -1)
-        attn = attn + attn_bias
-        attn = attn.softmax(dim=-1)
-        x = (attn @ v).transpose(1, 2)
-        x = x.reshape(B, N, self.dh)
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn + attn_bias
+            attn = attn.softmax(dim=-1)
+            x = attn @ v
+        x = x.transpose(1, 2).reshape(B, N, self.dh)
 
         x = self.proj(x)
         return x
@@ -257,7 +250,6 @@ class TinyVitBlock(nn.Module):
 
     Args:
        dim (int): Number of input channels.
-        input_resolution (tuple[int, int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
@@ -265,94 +257,89 @@ class TinyVitBlock(nn.Module):
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
-        activation: the activation function. Default: nn.GELU
+        act_layer: the activation function. Default: nn.GELU
    """
 
    def __init__(
-        self,
-        dim,
-        input_resolution,
-        num_heads,
-        window_size=7,
-        mlp_ratio=4.,
-        drop=0.,
-        drop_path=0.,
-        local_conv_size=3,
-        activation=nn.GELU
+            self,
+            dim,
+            num_heads,
+            window_size=7,
+            mlp_ratio=4.,
+            drop=0.,
+            drop_path=0.,
+            local_conv_size=3,
+            act_layer=nn.GELU
    ):
        super().__init__()
        self.dim = dim
-        self.input_resolution = input_resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
 
-        self.drop_path = DropPath(
-            drop_path) if drop_path > 0. else nn.Identity()
-
        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
        head_dim = dim // num_heads
 
        window_resolution = (window_size, window_size)
-        self.attn = Attention(dim, head_dim, num_heads,
-                              attn_ratio=1, resolution=window_resolution)
+        self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        mlp_activation = activation
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
-                       act_layer=mlp_activation, drop=drop)
+        self.mlp = NormMlp(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=drop,
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
        pad = local_conv_size // 2
-        self.local_conv = ConvNorm(
-            dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
+        self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
 
    def forward(self, x):
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        _assert(L == H * W, f"input feature has wrong size, expect {H * W}, got {L}")
-        res_x = x
+        B, H, W, C = x.shape
+        L = H * W
+
+        shortcut = x
        if H == self.window_size and W == self.window_size:
+            x = x.reshape(B, L, C)
            x = self.attn(x)
-        else:
            x = x.view(B, H, W, C)
+        else:
-            pad_b = (self.window_size - H %
-                     self.window_size) % self.window_size
-            pad_r = (self.window_size - W %
-                     self.window_size) % self.window_size
+            pad_b = (self.window_size - H % self.window_size) % self.window_size
+            pad_r = (self.window_size - W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0
            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
 
+            # window partition
            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
-            # window partition
            x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
                B * nH * nW, self.window_size * self.window_size, C
            )
+
            x = self.attn(x)
+
            # window reverse
-            x = x.view(B, nH, nW, self.window_size, self.window_size,
-                       C).transpose(2, 3).reshape(B, pH, pW, C)
+            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
 
            if padding:
                x = x[:, :H, :W].contiguous()
+        x = shortcut + self.drop_path1(x)
 
-            x = x.view(B, L, C)
-
-        x = res_x + self.drop_path(x)
-
-        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = x.permute(0, 3, 1, 2)
        x = self.local_conv(x)
-        x = x.view(B, C, L).transpose(1, 2)
+        x = x.reshape(B, C, L).transpose(1, 2)
 
-        x = x + self.drop_path(self.mlp(x))
-        return x
+        x = x + self.drop_path2(self.mlp(x))
+        return x.view(B, H, W, C)
 
    def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+        return f"dim={self.dim}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
@@ -361,7 +348,7 @@ class TinyVitStage(nn.Module):
 
    Args:
        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
+        out_dim: the output dimension of the layer
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
@@ -370,82 +357,80 @@ class TinyVitStage(nn.Module):
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
-        activation: the activation function. Default: nn.GELU
-        out_dim: the output dimension of the layer. Default: dim
-        in_fmt: input format ('BCHW' or 'BLC'). Default: 'BCHW'
+        act_layer: the activation function. Default: nn.GELU
    """
 
    def __init__(
-        self,
-        input_dim,
-        input_resolution,
-        depth,
-        num_heads,
-        window_size,
-        mlp_ratio=4.,
-        drop=0.,
-        drop_path=0.,
-        downsample=None,
-        local_conv_size=3,
-        activation=nn.GELU,
-        out_dim=None,
-        in_fmt='BCHW'
+            self,
+            dim,
+            out_dim,
+            depth,
+            num_heads,
+            window_size,
+            mlp_ratio=4.,
+            drop=0.,
+            drop_path=0.,
+            downsample=None,
+            local_conv_size=3,
+            act_layer=nn.GELU,
    ):
        super().__init__()
-        self.input_dim = input_dim
-        self.out_dim = out_dim
-        self.input_resolution = input_resolution
        self.depth = depth
 
        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
-                input_resolution, dim=input_dim, out_dim=self.out_dim, activation=activation, in_fmt=in_fmt)
-            input_resolution = self.downsample.output_resolution
+                dim=dim,
+                out_dim=out_dim,
+                act_layer=act_layer,
+            )
        else:
            self.downsample = nn.Identity()
-            self.out_dim = self.input_dim
+            assert dim == out_dim
 
        # build blocks
        self.blocks = nn.Sequential(*[
-            TinyVitBlock(dim=self.out_dim, input_resolution=input_resolution,
-                         num_heads=num_heads, window_size=window_size,
-                         mlp_ratio=mlp_ratio,
-                         drop=drop,
-                         drop_path=drop_path[i] if isinstance(
-                             drop_path, list) else drop_path,
-                         local_conv_size=local_conv_size,
-                         activation=activation,
-                         )
+            TinyVitBlock(
+                dim=out_dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                drop=drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                local_conv_size=local_conv_size,
+                act_layer=act_layer,
+            )
            for i in range(depth)])
 
    def forward(self, x):
        x = self.downsample(x)
+        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.blocks(x)
+        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW
        return x
 
    def extra_repr(self) -> str:
-        return f"dim={self.out_dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+        return f"dim={self.out_dim}, depth={self.depth}"
 
 
 class TinyVit(nn.Module):
    def __init__(
-        self,
-        img_size=224,
-        in_chans=3,
-        num_classes=1000,
-        embed_dims=[96, 192, 384, 768],
-        depths=[2, 2, 6, 2],
-        num_heads=[3, 6, 12, 24],
-        window_sizes=[7, 7, 14, 7],
-        mlp_ratio=4.,
-        drop_rate=0.,
-        drop_path_rate=0.1,
-        use_checkpoint=False,
-        mbconv_expand_ratio=4.0,
-        local_conv_size=3,
-        layer_lr_decay=1.0
+            self,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            embed_dims=(96, 192, 384, 768),
+            depths=(2, 2, 6, 2),
+            num_heads=(3, 6, 12, 24),
+            window_sizes=(7, 7, 14, 7),
+            mlp_ratio=4.,
+            drop_rate=0.,
+            drop_path_rate=0.1,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            act_layer=nn.GELU,
    ):
        super().__init__()
@@ -455,113 +440,70 @@ def __init__(
        self.mlp_ratio = mlp_ratio
        self.grad_checkpointing = use_checkpoint
 
-        activation = nn.GELU
-
-        self.patch_embed = PatchEmbed(in_chans=in_chans,
-                                      embed_dim=embed_dims[0],
-                                      resolution=img_size,
-                                      activation=activation)
-
-        patches_resolution = self.patch_embed.patches_resolution
-        self.patches_resolution = patches_resolution
+        self.patch_embed = PatchEmbed(
+            in_chs=in_chans,
+            out_chs=embed_dims[0],
+            act_layer=act_layer,
+        )
 
        # stochastic depth rate rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
 
        # build stages
        stages = nn.ModuleList()
-        input_resolution = patches_resolution
        stride = self.patch_embed.stride
+        prev_dim = embed_dims[0]
        self.feature_info = []
        for stage_idx in range(self.num_stages):
            if stage_idx == 0:
-                out_dim = embed_dims[0]
                stage = ConvLayer(
-                    dim=embed_dims[0],
-                    input_resolution=input_resolution,
-                    depth=depths[0],
-                    activation=activation,
-                    drop_path=dpr[:depths[0]],
-                    downsample=None,
+                    dim=prev_dim,
+                    depth=depths[stage_idx],
+                    act_layer=act_layer,
+                    drop_path=dpr[:depths[stage_idx]],
                    conv_expand_ratio=mbconv_expand_ratio,
                )
            else:
                out_dim = embed_dims[stage_idx]
                drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
-                if stage_idx == 1:
-                    in_fmt = 'BCHW'
-                else:
-                    in_fmt = 'BLC'
                stage = TinyVitStage(
+                    dim=embed_dims[stage_idx - 1],
+                    out_dim=out_dim,
+                    depth=depths[stage_idx],
                    num_heads=num_heads[stage_idx],
                    window_size=window_sizes[stage_idx],
                    mlp_ratio=self.mlp_ratio,
                    drop=drop_rate,
                    local_conv_size=local_conv_size,
-                    input_dim=embed_dims[stage_idx - 1],
-                    input_resolution=input_resolution,
-                    depth=depths[stage_idx],
                    drop_path=drop_path_rate,
                    downsample=PatchMerging,
-                    out_dim=out_dim,
-                    activation=activation,
-                    in_fmt=in_fmt
+                    act_layer=act_layer,
                )
-                input_resolution = (math.ceil(input_resolution[0] / 2), math.ceil(input_resolution[1] / 2))
+                prev_dim = out_dim
                stride *= 2
            stages.append(stage)
-            self.feature_info += [dict(num_chs=out_dim, reduction=stride, module=f'stages.{stage_idx}')]
+            self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
        self.stages = nn.Sequential(*stages)
 
        # Classifier head
        self.num_features = embed_dims[-1]
-        self.head = ClassifierHead(self.num_features, num_classes=num_classes)
+
+        norm_layer_cf = partial(LayerNorm2d, eps=1e-5)
+        self.head = NormMlpClassifierHead(
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            norm_layer=norm_layer_cf,
+        )
 
        # init weights
        self.apply(self._init_weights)
-        self.set_layer_lr_decay(layer_lr_decay)
-
-    @torch.jit.ignore
-    def set_layer_lr_decay(self, layer_lr_decay):
-        decay_rate = layer_lr_decay
-
-        # stages -> blocks (depth)
-        depth = sum(self.depths)
-        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
-
-        def _set_lr_scale(m, scale):
-            for p in m.parameters():
-                p.lr_scale = scale
-
-        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
-        i = 0
-        for stage in self.stages:
-            if hasattr(stage, 'downsample') and stage.downsample is not None:
-                stage.downsample.apply(
-                    lambda x: _set_lr_scale(x, lr_scales[i]))
-            for block in stage.blocks:
-                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
-            i += 1
-        assert i == depth
-        self.head.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
-
-        for k, p in self.named_parameters():
-            p.param_name = k
-
-        def _check_lr_scale(m):
-            for p in m.parameters():
-                assert hasattr(p, 'lr_scale'), p.param_name
-
-        self.apply(_check_lr_scale)
 
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
 
    @torch.jit.ignore
    def no_weight_decay_keywords(self):
@@ -583,9 +525,9 @@ def set_grad_checkpointing(self, enable=True):
    def get_classifier(self):
        return self.head
 
-    def reset_classifier(self, num_classes, **kwargs):
+    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
-        self.head = ClassifierHead(self.num_features, num_classes=num_classes)
+        self.head.reset(num_classes, global_pool=global_pool)
 
    def forward_features(self, x):
        x = self.patch_embed(x)
@@ -618,7 +560,7 @@ def checkpoint_filter_fn(state_dict, model):
        if 'attention_biases' in k:
            # dynamic window size by resampling relative_position_bias_table
            # TODO: whether move this func into model for dynamic input resolution? (high risk)
-            v = resample_relative_position_bias_table(v, targe_sd[target_keys[i]].shape)
+            v = resample_relative_position_bias_table(v.T, targe_sd[target_keys[i]].shape[::-1]).T
        out_dict[target_keys[i]] = v
        i += 1
    return out_dict
@@ -632,9 +574,9 @@ def _cfg(url='', **kwargs):
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv1.conv',
        'classifier': 'head.fc',
-        'fixed_input_size': True,
-        'pool_size': None,
+        'pool_size': (7, 7),
        'input_size': (3, 224, 224),
+        'crop_pct': 0.95,
        **kwargs,
    }
 
@@ -672,11 +614,11 @@ def _cfg(url='', **kwargs):
    ),
    'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
    ),
    'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
-        input_size=(3, 512, 512)
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
    ),
 })
@@ -736,7 +678,6 @@ def tiny_vit_21m_224(pretrained=False, **kwargs):
 @register_model
 def tiny_vit_21m_384(pretrained=False, **kwargs):
    model_kwargs = dict(
-        img_size=384,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
@@ -750,7 +691,6 @@ def tiny_vit_21m_384(pretrained=False, **kwargs):
 @register_model
 def tiny_vit_21m_512(pretrained=False, **kwargs):
    model_kwargs = dict(
-        img_size=512,
        embed_dims=[96, 192, 384, 576],
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 18],
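
Reviewer note, not part of the patch: the new fused branch in Attention.forward should be numerically equivalent to the manual branch it replaces, since F.scaled_dot_product_attention computes softmax(q @ k^T / sqrt(d) + attn_mask) @ v with a default scale of 1/sqrt(head_dim), and self.scale here is key_dim ** -0.5. A minimal equivalence sketch, assuming PyTorch >= 2.0 and illustrative shapes (attn_ratio=1, as TinyVitBlock uses):

import torch
import torch.nn.functional as F

B, num_heads, N, key_dim = 2, 3, 49, 16
scale = key_dim ** -0.5
q = torch.randn(B, num_heads, N, key_dim)
k = torch.randn(B, num_heads, N, key_dim)
v = torch.randn(B, num_heads, N, key_dim)
attn_bias = torch.randn(num_heads, N, N)  # stands in for the relative position bias

# Manual path (the else branch in the diff).
attn = (q * scale) @ k.transpose(-2, -1) + attn_bias
x_manual = attn.softmax(dim=-1) @ v

# Fused path: SDPA scales by 1/sqrt(key_dim) internally and adds attn_mask
# before the softmax, so both branches compute the same quantity.
x_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

print(torch.allclose(x_manual, x_fused, atol=1e-5))  # expect True

Since use_fused_attn() gates the branch at runtime, the manual path remains the reference implementation when fusion is unavailable or disabled.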
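
Also worth flagging: with img_size/input_resolution plumbing removed and fixed_input_size=True dropped from the default cfgs, one set of weights should now run at multiple input sizes (TinyVitBlock.forward pads H/W up to a multiple of window_size), and the new feature_info enables features_only extraction. A usage sketch, not part of the patch; the model name is one registered in this file, and the shapes assume the 4/8/16/32 stage strides set up by PatchEmbed/PatchMerging:

import torch
import timm

model = timm.create_model('tiny_vit_21m_224', pretrained=False).eval()

with torch.no_grad():
    # Same weights at two input sizes; non-multiples of window_size are
    # handled by the padding added in TinyVitBlock.forward.
    for size in (224, 256):
        logits = model(torch.randn(1, 3, size, size))
        print(size, logits.shape)  # torch.Size([1, 1000]) for both

# feature_info populated in TinyVit.__init__ is what drives features_only.
feat_model = timm.create_model('tiny_vit_21m_224', features_only=True).eval()
with torch.no_grad():
    for f in feat_model(torch.randn(1, 3, 224, 224)):
        print(f.shape)  # BCHW feature maps at reductions 4, 8, 16, 32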