diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py
index c7beb3d6f3..e850c03409 100644
--- a/timm/layers/pos_embed_sincos.py
+++ b/timm/layers/pos_embed_sincos.py
@@ -400,13 +400,12 @@ def __init__(
                 temperature=temperature,
                 step=1,
             )
-            print(bands)
             self.register_buffer(
                 'bands',
                 bands,
                 persistent=False,
             )
-            self.embed = None
+            self.pos_embed = None
         else:
             # cache full sin/cos embeddings if shape provided up front
             embeds = build_rotary_pos_embed(
@@ -425,17 +424,19 @@ def __init__(
             )
 
     def get_embed(self, shape: Optional[List[int]] = None):
-        if self.bands is not None:
+        if self.bands is not None and shape is not None:
             # rebuild embeddings every call, use if target shape changes
-            _assert(shape is not None, 'valid shape needed')
             embeds = build_rotary_pos_embed(
                 shape,
                 self.bands,
                 in_pixels=self.in_pixels,
+                ref_feat_shape=self.ref_feat_shape,
             )
             return torch.cat(embeds, -1)
-        else:
+        elif self.pos_embed is not None:
             return self.pos_embed
+        else:
+            assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
 
     def forward(self, x):
         # assuming channel-first tensor where spatial dim are >= 2
diff --git a/timm/models/eva.py b/timm/models/eva.py
index f0ab9c7224..7235132c91 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -367,6 +367,7 @@ def __init__(
             use_abs_pos_emb: bool = True,
             use_rot_pos_emb: bool = False,
             use_post_norm: bool = False,
+            dynamic_size: bool = False,
             ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
             head_init_scale: float = 0.001,
     ):
@@ -406,13 +407,19 @@ def __init__(
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_size = dynamic_size
         self.grad_checkpointing = False
 
+        embed_args = {}
+        if dynamic_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
             in_chans=in_chans,
             embed_dim=embed_dim,
+            **embed_args,
         )
         num_patches = self.patch_embed.num_patches
 
@@ -435,7 +442,7 @@ def __init__(
             self.rope = RotaryEmbeddingCat(
                 embed_dim // num_heads,
                 in_pixels=False,
-                feat_shape=self.patch_embed.grid_size,
+                feat_shape=None if dynamic_size else self.patch_embed.grid_size,
                 ref_feat_shape=ref_feat_shape,
             )
         else:
@@ -519,30 +526,44 @@ def reset_classifier(self, num_classes, global_pool=None):
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward_features(self, x):
-        x = self.patch_embed(x)
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if self.dynamic_size:
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
 
         if self.cls_token is not None:
             x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-
-        # apply abs position embedding
-        if self.pos_embed is not None:
-            x = x + self.pos_embed
+        if pos_embed is not None:
+            x = x + pos_embed
         x = self.pos_drop(x)
 
         # obtain shared rotary position embedding and apply patch dropout
-        rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
         if self.patch_drop is not None:
            x, keep_indices = self.patch_drop(x)
            if rot_pos_embed is not None and keep_indices is not None:
                rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed
 
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
         for blk in self.blocks:
             if self.grad_checkpointing and not torch.jit.is_scripting():
                 x = checkpoint(blk, x, rope=rot_pos_embed)
             else:
                 x = blk(x, rope=rot_pos_embed)
-
         x = self.norm(x)
         return x
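
Usage sketch (not part of the patch): how the two RotaryEmbeddingCat code paths behave after this change. The keyword names in_pixels, feat_shape, ref_feat_shape and the shape argument to get_embed() come from the hunks above; the dim value and grid sizes are illustrative only.

from timm.layers.pos_embed_sincos import RotaryEmbeddingCat

# Lazy path: no feat_shape at construction, so only the frequency bands are
# registered and get_embed() must be given a target (H, W) each call; the
# sin/cos tables are rebuilt on the fly, now also honoring ref_feat_shape.
rope = RotaryEmbeddingCat(
    64,                      # per-head dim, i.e. embed_dim // num_heads
    in_pixels=False,
    feat_shape=None,
    ref_feat_shape=(16, 16),
)
emb_16 = rope.get_embed(shape=(16, 16))
emb_24 = rope.get_embed(shape=(24, 24))  # different grid, rebuilt per call

# Cached path: feat_shape known up front, sin/cos stored as self.pos_embed and
# returned directly when get_embed() is called without a shape.
rope_fixed = RotaryEmbeddingCat(64, in_pixels=False, feat_shape=(16, 16))
emb_fixed = rope_fixed.get_embed()

# Calling get_embed() with neither a cached pos_embed nor a shape now trips the
# explicit assertion added above rather than failing further downstream.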
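
Model-level sketch of the dynamic_size flow (also not part of the patch): with dynamic_size=True the patch embed output stays NHWC and flattening is deferred, so _pos_embed() can resample the absolute position embedding and request rope tables for whatever grid the input produces. Constructor arguments not visible in the hunks above (e.g. depth) are assumed from timm's usual ViT conventions, so treat this as illustrative rather than a tested snippet.

import torch
from timm.models.eva import Eva

model = Eva(
    img_size=224,
    patch_size=16,
    embed_dim=192,
    depth=2,                  # assumed parameter name, kept tiny for the sketch
    num_heads=3,
    use_rot_pos_emb=True,
    ref_feat_shape=(14, 14),  # 224 // 16
    dynamic_size=True,
).eval()

with torch.no_grad():
    out_native = model(torch.randn(1, 3, 224, 224))  # native grid, same as before
    out_larger = model(torch.randn(1, 3, 256, 256))  # larger input: pos_embed is
                                                     # resampled, rope rebuilt per grid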