Skip to content

Commit

Permalink
TinyVitBlock needs adding as leaf for FX now, tweak a few dim names
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 1, 2023
1 parent 9caf32b commit 507cb08
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions timm/models/tiny_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

Expand Down Expand Up @@ -178,18 +179,15 @@ def __init__(
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.val_dim = int(attn_ratio * key_dim)
self.out_dim = self.val_dim * num_heads
self.attn_ratio = attn_ratio
self.resolution = resolution
self.fused_attn = use_fused_attn()

h = self.dh + nh_kd * 2

self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim))
self.proj = nn.Linear(self.out_dim, dim)

points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
Expand Down Expand Up @@ -227,7 +225,7 @@ def forward(self, x):
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
Expand All @@ -241,7 +239,7 @@ def forward(self, x):
attn = attn + attn_bias
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, self.dh)
x = x.transpose(1, 2).reshape(B, N, self.out_dim)
x = self.proj(x)
return x

Expand Down Expand Up @@ -311,7 +309,6 @@ def forward(self, x):
pad_b = (self.window_size - H % self.window_size) % self.window_size
pad_r = (self.window_size - W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0

if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

Expand Down Expand Up @@ -344,6 +341,9 @@ def extra_repr(self) -> str:
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"


register_notrace_module(TinyVitBlock)


class TinyVitStage(nn.Module):
""" A basic TinyViT layer for one stage.
Expand Down

0 comments on commit 507cb08

Please sign in to comment.