diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 107628dba8..ff753679c0 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -606,6 +606,7 @@ def __init__(
             self.attn_pool = AttentionPoolLatent(
                 self.embed_dim,
                 num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
                 norm_layer=norm_layer,
             )
         else:
@@ -1644,6 +1645,39 @@ def _cfg(url='', **kwargs):
         input_size=(3, 256, 256),
         # hf_hub_id='timm/',
         num_classes=0),
+    'vit_base_patch16_siglip_384': _cfg(
+        file='',
+        custom_load=True,
+        input_size=(3, 384, 384),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_base_patch16_siglip_512': _cfg(
+        file='',
+        custom_load=True,
+        input_size=(3, 512, 512),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_large_patch16_siglip_256': _cfg(
+        custom_load=True,
+        input_size=(3, 256, 256),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_large_patch16_siglip_384': _cfg(
+        custom_load=True,
+        input_size=(3, 384, 384),
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_224': _cfg(
+        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_384': _cfg(
+        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
+        custom_load=True,
+        # hf_hub_id='timm/',
+        input_size=(3, 384, 384),
+        num_classes=0),
 })
 
 
@@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer
     return model
 
 
+@register_model
+def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer: