Add full set of SigLIP models
rwightman committed Oct 11, 2023
1 parent b9dde58 commit 42daa3b
Showing 1 changed file with 93 additions and 0 deletions.
timm/models/vision_transformer.py (93 additions, 0 deletions)
@@ -606,6 +606,7 @@ def __init__(
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
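
The single added line above forwards mlp_ratio into AttentionPoolLatent, so the attention-pool MLP width follows the block MLP ratio rather than the pool's default. A rough illustrative sketch (not part of the commit) of why this matters for the SoViT-400M configs added further down, assuming the pool sizes its hidden layer as int(embed_dim * mlp_ratio):

# Illustrative sketch only, not part of the diff: assumed pool MLP sizing.
embed_dim, mlp_ratio = 1152, 3.7362           # SoViT-400M shape registered below
pool_mlp_hidden = int(embed_dim * mlp_ratio)  # 4304, vs. 4608 with the default ratio of 4.0
print(pool_mlp_hidden)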
@@ -1644,6 +1645,39 @@ def _cfg(url='', **kwargs):
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_384': _cfg(
        file='',
        custom_load=True,
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_512': _cfg(
        file='',
        custom_load=True,
        input_size=(3, 512, 512),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_256': _cfg(
        custom_load=True,
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_384': _cfg(
        custom_load=True,
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_so400m_patch14_siglip_224': _cfg(
        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_so400m_patch14_siglip_384': _cfg(
        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
        input_size=(3, 384, 384),
        num_classes=0),
})
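
The entries above register the new SigLIP image towers as feature-only configs (num_classes=0) at 256/384/512 px; the hf_hub_id lines are still commented out and the file paths are blank or local, so pretrained weights are not hooked up in this commit. A minimal sketch of inspecting one of the registered configs, assuming the pretrained_cfg dict exposed on models by recent timm versions:

import timm

# Sketch only, not part of the diff; the pretrained_cfg attribute is assumed from recent timm.
model = timm.create_model('vit_base_patch16_siglip_512', pretrained=False)
cfg = model.pretrained_cfg
print(cfg['input_size'], cfg['num_classes'])  # expected: (3, 512, 512) 0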


@@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer
    return model


@register_model
def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
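
Usage sketch for the new entrypoints (not part of the diff): weights are not yet published in this commit, so build with pretrained=False; with num_classes=0 and global_pool='map' the model returns the attention-pooled embedding.

import torch
import timm

# Sketch only, not part of the commit.
model = timm.create_model('vit_so400m_patch14_siglip_224', pretrained=False, num_classes=0)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 1152])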


@register_model
def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
