diff --git a/timm/models/vision_transformer_packed.py b/timm/models/vision_transformer_packed.py
index 860f300a80..c2f71f94ba 100644
--- a/timm/models/vision_transformer_packed.py
+++ b/timm/models/vision_transformer_packed.py
@@ -84,13 +84,21 @@ def add_image(self, tokens, pos_indices):
         self.total_len += seq_len
         self.num_images += 1
 
-    def to_tensors(self, max_len, max_packed, return_mask=True):
+    def to_tensors(self, max_seq_len, max_num_seq):
+        """
+        Args:
+            max_seq_len: maximum sequence length (pad to this)
+            max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)
+
+        Returns:
+            Tuple of tensors for packed batch of images
+        """
         assert self.total_len > 0
-        assert max_len >= self.total_len
+        assert max_seq_len >= self.total_len
         device = self.tokens[-1].device
         dim = self.tokens[-1].shape[-1]
-        pad_len = max_len - self.total_len
-        seq_pad = max(0, max_packed - len(self.seq_lens))
+        pad_len = max_seq_len - self.total_len
+        seq_pad = max(0, max_num_seq - len(self.seq_lens))
         seq_lens = self.seq_lens + [0] * seq_pad if seq_pad else self.seq_lens
         seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
         if pad_len:
@@ -104,9 +112,6 @@ def to_tensors(self, max_len, max_packed, return_mask=True):
         tokens = torch.concat(tokens)
         pos_indices = torch.concat(pos_indices)
         seq_ids = torch.concat(seq_ids)
-        if return_mask:
-            mask = seq_ids != 0
-            return tokens, pos_indices, seq_ids, seq_lens, mask
         return tokens, pos_indices, seq_ids, seq_lens
 
 
@@ -173,7 +178,7 @@ def pack_images(
             max_packed = max(sequence.num_images, max_packed)
         next_pos += 1
 
-    tensors = [p.to_tensors(max_len=max_seq_len, max_packed=max_packed) for p in packed_sequences]
+    tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
     o = [torch.stack(t) for t in zip(*tensors)]
     return tuple(o)
 
@@ -655,12 +660,12 @@ def init_weights(self, mode=''):
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'embeds.pos_embed', 'embeds.cls_token'}
+        return {'pos_embed_h', 'pos_embed_w'}
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(
-            stem=r'^embeds',  # stem and embed
+            stem=r'^embeds',  # stem and embed  # FIXME correct when design finalized
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
@@ -675,7 +680,7 @@ def get_classifier(self):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'attn')
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
@@ -693,7 +698,7 @@ def forward_features(
             tokens = tokens.unbind(0)
 
         if isinstance(tokens, (list, tuple)):
-            tokens, pos_indices, seq_ids, seq_lens, padding_mask = pack_images(
+            tokens, pos_indices, seq_ids, seq_lens = pack_images(
                 tokens,
                 self.patch_size,
                 max_grid_size=self.grid_size,
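
Reviewer note: since `to_tensors` no longer takes `return_mask` and `pack_images` now returns a 4-tuple, any downstream code that relied on the dropped `padding_mask` output can rebuild it from `seq_ids`. The deleted branch computed exactly `mask = seq_ids != 0`, i.e. padding tokens carry sequence id 0 and real tokens a nonzero id. A minimal self-contained sketch of that reconstruction (the `seq_ids` values below are illustrative, not taken from the file):

    import torch

    # One packed row: three tokens from image 1, two from image 2, three pad tokens.
    # to_tensors / pack_images assign seq id 0 to padding, > 0 to real images.
    seq_ids = torch.tensor([[1, 1, 1, 2, 2, 0, 0, 0]])
    padding_mask = seq_ids != 0  # True where a token belongs to a real image
    print(padding_mask)
    # tensor([[ True,  True,  True,  True,  True, False, False, False]])

Deriving the mask at the point of use keeps the packed-tensor return signature uniform, at the cost of one comparison op per caller that needs it.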