Fixup attention pooling in siglip vit support
rwightman committed Oct 2, 2023
1 parent 99cfd67 commit b9dde58
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions timm/models/vision_transformer.py
@@ -389,7 +389,7 @@ def __init__(
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
latent_size: int = 1,
latent_len: int = 1,
latent_dim: int = None,
pos_embed: str = '',
pool_type: str = 'token',
@@ -404,6 +404,7 @@ def __init__(
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.pool = pool_type
self.fused_attn = use_fused_attn()

if pos_embed == 'abs':
spatial_len = self.feat_size
@@ -412,11 +413,16 @@ def __init__(
self.pos_embed = None

self.latent_dim = latent_dim or embed_dim
latent_size = latent_size or self.feat_size
self.latent_len = latent_size
self.latent_len = latent_len
self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm)
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(drop)

self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

@@ -425,14 +431,31 @@ def init_weights(self):
trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)

def forward(self, x):
B, N, _ = x.shape
B, N, C = x.shape

if self.pos_embed is not None:
# FIXME interpolate
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

latent_q = self.latent.expand(B, -1, -1)
x = self.attn(torch.cat([latent_q, x], dim=1))
q_latent = self.latent.expand(B, -1, -1)
q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
k, v = kv.unbind(0)

q, k = self.q_norm(q), self.k_norm(k)

if self.fused_attn:
x = F.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
x = self.proj(x)
x = self.proj_drop(x)

x = x + self.mlp(self.norm(x))

# optional pool if latent seq_len > 1 and pooled output is desired
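Note on the forward pass above: the fused (F.scaled_dot_product_attention) branch and the manual branch compute the same cross-attention from the latent queries to the input tokens. A quick standalone check of that equivalence, with illustrative shapes (batch, heads, latent_len, tokens, head_dim) chosen to match a ViT-Base pool with latent_len=1:

import torch
import torch.nn.functional as F

B, H, L, N, hd = 2, 12, 1, 196, 64
q = torch.randn(B, H, L, hd)            # queries come from the latent tokens
k = torch.randn(B, H, N, hd)            # keys/values come from the input tokens
v = torch.randn(B, H, N, hd)

fused = F.scaled_dot_product_attention(q, k, v)
attn = ((q * hd ** -0.5) @ k.transpose(-2, -1)).softmax(dim=-1)
manual = attn @ v                       # (B, H, L, hd), same result as the fused path
print(torch.allclose(fused, manual, atol=1e-5))  # True up to floating-point noise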
@@ -579,7 +602,7 @@ def __init__(
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

# Classifier Head
if global_pool == 'pool':
if global_pool == 'map':
self.attn_pool = AttentionPoolLatent(
self.embed_dim,
num_heads=num_heads,
@@ -932,14 +955,16 @@ def _n2p(w, t=True):
block_prefix = f'{prefix}MAPHead_0/'
mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
model.attn_pool.kv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
model.attn_pool.kv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
model.attn_pool.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
model.attn_pool.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
model.attn_pool.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
model.attn_pool.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
for r in range(2):
getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))
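For readers tracing the remap above: the MAP-head kernels in the .npz are stored per-head, and the key/value kernels are flattened and concatenated into the pool's single kv projection. A small shape-only sketch of that bookkeeping (the (C, num_heads, head_dim) kernel layout is an assumption inferred from the flatten(1).T pattern used here; the arrays are hypothetical stand-ins for the checkpoint entries):

import numpy as np
import torch

C, H = 768, 12                       # embed dim and heads for ViT-Base
hd = C // H                          # head_dim = 64
k_kernel = np.zeros((C, H, hd), dtype=np.float32)   # stand-in for f'{mha_prefix}key/kernel'
v_kernel = np.zeros((C, H, hd), dtype=np.float32)   # stand-in for f'{mha_prefix}value/kernel'

# flatten the head dims, transpose to nn.Linear's (out, in) layout, stack key then value
kv_weight = torch.cat([torch.from_numpy(a).flatten(1).T for a in (k_kernel, v_kernel)])
print(kv_weight.shape)               # torch.Size([1536, 768]) == attn_pool.kv.weight.shape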
@@ -956,11 +981,11 @@ def _n2p(w, t=True):
_n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
for r in range(2):
getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))


def _convert_openai_clip(state_dict, model):
@@ -1613,6 +1638,12 @@ def _cfg(url='', **kwargs):
custom_load=True,
# hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_256': _cfg(
file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
custom_load=True,
input_size=(3, 256, 256),
# hf_hub_id='timm/',
num_classes=0),
})


@@ -2249,6 +2280,17 @@ def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer
return model


@register_model
def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
)
model = _create_vision_transformer(
'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
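As a quick sanity check for the registration above, a hedged usage sketch (assuming timm's usual create_model entrypoint; the pretrained cfg in this commit still points at a local .npz file, so only the architecture is built here):

import timm
import torch

model = timm.create_model('vit_base_patch16_siglip_256', pretrained=False, num_classes=0)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 256))
print(feats.shape)   # torch.Size([1, 768]) -- attention-pooled features, no classifier head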



@register_model
def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
model_args = dict(