diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index c063298312..b04ac8fabe 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -21,6 +21,7 @@ from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\ trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_module from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs @@ -178,18 +179,15 @@ def __init__( self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads + self.val_dim = int(attn_ratio * key_dim) + self.out_dim = self.val_dim * num_heads self.attn_ratio = attn_ratio self.resolution = resolution self.fused_attn = use_fused_attn() - h = self.dh + nh_kd * 2 - self.norm = nn.LayerNorm(dim) - self.qkv = nn.Linear(dim, h) - self.proj = nn.Linear(self.dh, dim) + self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim)) + self.proj = nn.Linear(self.out_dim, dim) points = list(itertools.product(range(resolution[0]), range(resolution[1]))) N = len(points) @@ -227,7 +225,7 @@ def forward(self, x): x = self.norm(x) qkv = self.qkv(x) # (B, N, num_heads, d) - q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3) # (B, num_heads, N, d) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) @@ -241,7 +239,7 @@ def forward(self, x): attn = attn + attn_bias attn = attn.softmax(dim=-1) x = attn @ v - x = x.transpose(1, 2).reshape(B, N, self.dh) + x = x.transpose(1, 2).reshape(B, N, self.out_dim) x = self.proj(x) return x @@ -311,7 +309,6 @@ def forward(self, x): pad_b = (self.window_size - H % self.window_size) % self.window_size pad_r = (self.window_size - W % self.window_size) % self.window_size padding = pad_b > 0 or pad_r > 0 - if padding: x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) @@ -344,6 +341,9 @@ def extra_repr(self) -> str: f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" +register_notrace_module(TinyVitBlock) + + class TinyVitStage(nn.Module): """ A basic TinyViT layer for one stage.