Skip to content

Commit

Permalink
TinyViT weights on HF hub
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 1, 2023
1 parent 507cb08 commit 2f0fbb5
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions timm/models/tiny_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __init__(
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

# build stages
stages = nn.ModuleList()
self.stages = nn.Sequential()
stride = self.patch_embed.stride
prev_dim = embed_dims[0]
self.feature_info = []
Expand Down Expand Up @@ -482,9 +482,8 @@ def __init__(
)
prev_dim = out_dim
stride *= 2
stages.append(stage)
self.stages.append(stage)
self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)

# Classifier head
self.num_features = embed_dims[-1]
Expand Down Expand Up @@ -549,22 +548,17 @@ def forward(self, x):


def checkpoint_filter_fn(state_dict, model):
# TODO: temporary use for testing, need change after weight convert
if 'model' in state_dict.keys():
state_dict = state_dict['model']
target_sd = model.state_dict()
target_keys = list(target_sd.keys())
out_dict = {}
i = 0
for k, v in state_dict.items():
if k.endswith('attention_bias_idxs'):
continue
tk = target_keys[i]
if 'attention_biases' in k:
# TODO: whether move this func into model for dynamic input resolution? (high risk)
v = resize_rel_pos_bias_table_levit(v.T, target_sd[tk].shape[::-1]).T
out_dict[tk] = v
i += 1
v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T
out_dict[k] = v
return out_dict


Expand All @@ -585,41 +579,52 @@ def _cfg(url='', **kwargs):

default_cfgs = generate_default_cfgs({
'tiny_vit_5m_224.dist_in22k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
num_classes=21841
),
'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
),
'tiny_vit_5m_224.in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
),
'tiny_vit_11m_224.dist_in22k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
num_classes=21841
),
'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
),
'tiny_vit_11m_224.in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
),
'tiny_vit_21m_224.dist_in22k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
num_classes=21841
),
'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
),
'tiny_vit_21m_224.in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
hf_hub_id='timm/',
#url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
),
'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
),
'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
hf_hub_id='timm/',
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
),
})
Expand Down

0 comments on commit 2f0fbb5

Please sign in to comment.