diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py
index b04ac8fabe..b3a6009cd7 100644
--- a/timm/models/tiny_vit.py
+++ b/timm/models/tiny_vit.py
@@ -451,7 +451,7 @@ def __init__(
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
 
         # build stages
-        stages = nn.ModuleList()
+        self.stages = nn.Sequential()
         stride = self.patch_embed.stride
         prev_dim = embed_dims[0]
         self.feature_info = []
@@ -482,9 +482,8 @@ def __init__(
                 )
                 prev_dim = out_dim
                 stride *= 2
-            stages.append(stage)
+            self.stages.append(stage)
             self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
-        self.stages = nn.Sequential(*stages)
 
         # Classifier head
         self.num_features = embed_dims[-1]
@@ -549,22 +548,17 @@ def forward(self, x):
 
 
 def checkpoint_filter_fn(state_dict, model):
-    # TODO: temporary use for testing, need change after weight convert
     if 'model' in state_dict.keys():
         state_dict = state_dict['model']
     target_sd = model.state_dict()
-    target_keys = list(target_sd.keys())
     out_dict = {}
-    i = 0
     for k, v in state_dict.items():
         if k.endswith('attention_bias_idxs'):
             continue
-        tk = target_keys[i]
         if 'attention_biases' in k:
             # TODO: whether move this func into model for dynamic input resolution? (high risk)
-            v = resize_rel_pos_bias_table_levit(v.T, target_sd[tk].shape[::-1]).T
-        out_dict[tk] = v
-        i += 1
+            v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T
+        out_dict[k] = v
     return out_dict
 
 
@@ -585,41 +579,52 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
     'tiny_vit_5m_224.dist_in22k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
         num_classes=21841
     ),
     'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
     ),
     'tiny_vit_5m_224.in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
     ),
     'tiny_vit_11m_224.dist_in22k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
         num_classes=21841
     ),
     'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
     ),
     'tiny_vit_11m_224.in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
     ),
     'tiny_vit_21m_224.dist_in22k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
         num_classes=21841
     ),
     'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
     ),
     'tiny_vit_21m_224.in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
+        hf_hub_id='timm/',
+        #url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
     ),
     'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
     ),
     'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
-        url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
         input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
     ),
 })