Support dynamic_resize in eva.py models
rwightman committed Aug 27, 2023
1 parent 9af5a5b commit 8515e28
Showing 2 changed files with 36 additions and 14 deletions.
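
As a usage sketch (not part of the commit; the variant name and input sizes are illustrative assumptions), the new dynamic_size flag lets an EVA model accept inputs at resolutions other than img_size, with position embeddings resampled or rebuilt to match each input:

import torch
import timm

# dynamic_size is assumed to reach Eva.__init__ via create_model kwargs
model = timm.create_model('eva02_tiny_patch14_224', pretrained=False, dynamic_size=True).eval()

with torch.inference_mode():
    for size in (224, 280):  # any multiple of the 14-pixel patch size
        x = torch.randn(1, 3, size, size)
        tokens = model.forward_features(x)  # token count tracks the input grid
        logits = model(x)                   # classifier output shape is unchanged
        print(size, tokens.shape, logits.shape)
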
11 changes: 6 additions & 5 deletions timm/layers/pos_embed_sincos.py
@@ -400,13 +400,12 @@ def __init__(
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.embed = None
self.pos_embed = None
else:
# cache full sin/cos embeddings if shape provided up front
embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@
)

def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
if self.bands is not None and shape is not None:
# rebuild embeddings every call, use if target shape changes
_assert(shape is not None, 'valid shape needed')
embeds = build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
return torch.cat(embeds, -1)
else:
elif self.pos_embed is not None:
return self.pos_embed
else:
assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"

def forward(self, x):
# assuming channel-first tensor where spatial dims are >= 2
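
A standalone sketch of the reworked get_embed() contract (head_dim and grid sizes are illustrative): with feat_shape=None only the frequency bands are buffered, so each call must pass a target shape; with a fixed feat_shape the full sin/cos table is cached up front and no shape is needed:

import torch
from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

head_dim = 64  # e.g. embed_dim=768 across 12 heads

static_rope = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=(16, 16))
dynamic_rope = RotaryEmbeddingCat(head_dim, in_pixels=False, feat_shape=None)

print(static_rope.get_embed().shape)                 # cached table for the fixed 16x16 grid
print(dynamic_rope.get_embed(shape=(20, 20)).shape)  # rebuilt from bands for a 20x20 grid
# dynamic_rope.get_embed() would now hit the assert: no cached table and no shape given
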
39 changes: 30 additions & 9 deletions timm/models/eva.py
Expand Up @@ -367,6 +367,7 @@ def __init__(
use_abs_pos_emb: bool = True,
use_rot_pos_emb: bool = False,
use_post_norm: bool = False,
dynamic_size: bool = False,
ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
head_init_scale: float = 0.001,
):
@@ -406,13 +407,19 @@ def __init__(
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.dynamic_size = dynamic_size
self.grad_checkpointing = False

embed_args = {}
if dynamic_size:
# flatten deferred until after pos embed
embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
**embed_args,
)
num_patches = self.patch_embed.num_patches
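
The embed_args above defer flattening so the patch grid survives to _pos_embed(); a standalone sketch of that behaviour (embed_dim and input sizes are illustrative):

import torch
from timm.layers import PatchEmbed

embed = PatchEmbed(
    img_size=224, patch_size=14, in_chans=3, embed_dim=192,
    strict_img_size=False, output_fmt='NHWC',
)
print(embed(torch.randn(1, 3, 224, 224)).shape)  # (1, 16, 16, 192), NHWC patch grid
print(embed(torch.randn(1, 3, 280, 280)).shape)  # (1, 20, 20, 192), grid follows the input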

@@ -435,7 +442,7 @@ def __init__(
self.rope = RotaryEmbeddingCat(
embed_dim // num_heads,
in_pixels=False,
feat_shape=self.patch_embed.grid_size,
feat_shape=None if dynamic_size else self.patch_embed.grid_size,
ref_feat_shape=ref_feat_shape,
)
else:
@@ -519,30 +526,44 @@ def reset_classifier(self, num_classes, global_pool=None):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def forward_features(self, x):
x = self.patch_embed(x)
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.dynamic_size:
B, H, W, C = x.shape
if self.pos_embed is not None:
pos_embed = resample_abs_pos_embed(
self.pos_embed,
(H, W),
num_prefix_tokens=self.num_prefix_tokens,
)
else:
pos_embed = None
x = x.view(B, -1, C)
rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
else:
pos_embed = self.pos_embed
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

# apply abs position embedding
if self.pos_embed is not None:
x = x + self.pos_embed
if pos_embed is not None:
x = x + pos_embed
x = self.pos_drop(x)

# obtain shared rotary position embedding and apply patch dropout
rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
if self.patch_drop is not None:
x, keep_indices = self.patch_drop(x)
if rot_pos_embed is not None and keep_indices is not None:
rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
return x, rot_pos_embed

def forward_features(self, x):
x = self.patch_embed(x)
x, rot_pos_embed = self._pos_embed(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, rope=rot_pos_embed)
else:
x = blk(x, rope=rot_pos_embed)

x = self.norm(x)
return x

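The dynamic branch of _pos_embed() relies on resample_abs_pos_embed to interpolate the learned absolute table to the current grid while carrying the class-token slot over untouched; a standalone sketch with illustrative shapes:

import torch
from timm.layers import resample_abs_pos_embed

embed_dim = 192
pos_embed = torch.randn(1, 1 + 16 * 16, embed_dim)  # 16x16 grid plus one class token

resampled = resample_abs_pos_embed(
    pos_embed,
    (20, 20),               # target (H, W) of the incoming patch grid
    num_prefix_tokens=1,    # prefix (class) token is passed through, not interpolated
)
print(resampled.shape)      # (1, 1 + 20 * 20, embed_dim)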
