Skip to content

Commit

Permalink
Remove padding calc from pack, minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 14, 2023
1 parent d81f75b commit f93083e
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions timm/models/vision_transformer_packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,21 @@ def add_image(self, tokens, pos_indices):
self.total_len += seq_len
self.num_images += 1

def to_tensors(self, max_len, max_packed, return_mask=True):
def to_tensors(self, max_seq_len, max_num_seq):
"""
Args:
max_seq_len: maximum sequence length (pad to this)
max_num_seq: maximum # of sequences (images) packed into one sequence (across the batch)
Returns:
Tuple of tensors for packed batch of images
"""
assert self.total_len > 0
assert max_len >= self.total_len
assert max_seq_len >= self.total_len
device = self.tokens[-1].device
dim = self.tokens[-1].shape[-1]
pad_len = max_len - self.total_len
seq_pad = max(0, max_packed - len(self.seq_lens))
pad_len = max_seq_len - self.total_len
seq_pad = max(0, max_num_seq - len(self.seq_lens))
seq_lens = self.seq_lens + [0] * seq_pad if seq_pad else self.seq_lens
seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device)
if pad_len:
Expand All @@ -104,9 +112,6 @@ def to_tensors(self, max_len, max_packed, return_mask=True):
tokens = torch.concat(tokens)
pos_indices = torch.concat(pos_indices)
seq_ids = torch.concat(seq_ids)
if return_mask:
mask = seq_ids != 0
return tokens, pos_indices, seq_ids, seq_lens, mask
return tokens, pos_indices, seq_ids, seq_lens


Expand Down Expand Up @@ -173,7 +178,7 @@ def pack_images(
max_packed = max(sequence.num_images, max_packed)
next_pos += 1

tensors = [p.to_tensors(max_len=max_seq_len, max_packed=max_packed) for p in packed_sequences]
tensors = [p.to_tensors(max_seq_len=max_seq_len, max_num_seq=max_packed) for p in packed_sequences]
o = [torch.stack(t) for t in zip(*tensors)]
return tuple(o)

@torch.jit.ignore
def no_weight_decay(self):
    """Return parameter names to exclude from weight decay.

    Returns:
        Set of parameter-name strings; presumably the per-axis (height/width)
        position embedding tables — confirm against the module's attributes.
    """
    # Diff residue removed: the pre-commit set {'embeds.pos_embed',
    # 'embeds.cls_token'} was superseded in this commit by the set below.
    return {'pos_embed_h', 'pos_embed_w'}

@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return regex-based parameter grouping for layer-wise LR decay utilities.

    Args:
        coarse: accepted for interface compatibility; not used in the
            visible body.

    Returns:
        Dict mapping group names to regex patterns (or pattern/ordering
        tuples) matched against parameter names.
    """
    # Diff residue removed: the pre- and post-commit lines differed only in
    # the trailing comment; the post-commit version is kept.
    return dict(
        stem=r'^embeds',  # stem and embed # FIXME correct when design finalized
        blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
    )

Expand All @@ -675,7 +680,7 @@ def get_classifier(self):
def reset_classifier(self, num_classes: int, global_pool=None):
    """Replace the classification head (and optionally the pooling mode).

    Args:
        num_classes: Output classes for the new head; 0 installs an
            ``nn.Identity`` head (feature-extraction mode).
        global_pool: If not ``None``, the new pooling mode. Must be one of
            ``''``, ``'avg'``, ``'attn'``.
    """
    self.num_classes = num_classes
    if global_pool is not None:
        # Diff residue removed: this commit replaced 'token' with 'attn' in
        # the set of valid pooling modes; the post-commit assert is kept.
        assert global_pool in ('', 'avg', 'attn')
        self.global_pool = global_pool
    self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

Expand All @@ -693,7 +698,7 @@ def forward_features(
tokens = tokens.unbind(0)

if isinstance(tokens, (list, tuple)):
tokens, pos_indices, seq_ids, seq_lens, padding_mask = pack_images(
tokens, pos_indices, seq_ids, seq_lens = pack_images(
tokens,
self.patch_size,
max_grid_size=self.grid_size,
Expand Down

0 comments on commit f93083e

Please sign in to comment.