diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index caec5e696e..8ce3687ef2 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -37,7 +37,8 @@ from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc -from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords +from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \ + resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \ build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \ FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat diff --git a/timm/layers/interpolate.py b/timm/layers/interpolate.py new file mode 100644 index 0000000000..adba9342ec --- /dev/null +++ b/timm/layers/interpolate.py @@ -0,0 +1,68 @@ +""" Interpolation helpers for timm layers + +RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations +Copyright Shane Barratt, Apache 2.0 license +""" +import torch +from itertools import product + + +class RegularGridInterpolator: + """ Interpolate data defined on a rectilinear grid with even or uneven spacing. + Produces similar results to scipy RegularGridInterpolator or interp2d + in 'linear' mode. + + Taken from https://github.com/sbarratt/torch_interpolations + """ + + def __init__(self, points, values): + self.points = points + self.values = values + + assert isinstance(self.points, tuple) or isinstance(self.points, list) + assert isinstance(self.values, torch.Tensor) + + self.ms = list(self.values.shape) + self.n = len(self.points) + + assert len(self.ms) == self.n + + for i, p in enumerate(self.points): + assert isinstance(p, torch.Tensor) + assert p.shape[0] == self.values.shape[i] + + def __call__(self, points_to_interp): + assert self.points is not None + assert self.values is not None + + assert len(points_to_interp) == len(self.points) + K = points_to_interp[0].shape[0] + for x in points_to_interp: + assert x.shape[0] == K + + idxs = [] + dists = [] + overalls = [] + for p, x in zip(self.points, points_to_interp): + idx_right = torch.bucketize(x, p) + idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1 + idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1) + dist_left = x - p[idx_left] + dist_right = p[idx_right] - x + dist_left[dist_left < 0] = 0. + dist_right[dist_right < 0] = 0. + both_zero = (dist_left == 0) & (dist_right == 0) + dist_left[both_zero] = dist_right[both_zero] = 1. + + idxs.append((idx_left, idx_right)) + dists.append((dist_left, dist_right)) + overalls.append(dist_left + dist_right) + + numerator = 0. + for indexer in product([0, 1], repeat=self.n): + as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)] + bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)] + numerator += self.values[as_s] * \ + torch.prod(torch.stack(bs_s), dim=0) + denominator = torch.prod(torch.stack(overalls), dim=0) + return numerator / denominator diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py index 5cb3d0f4dd..dc4377d667 100644 --- a/timm/layers/pos_embed_rel.py +++ b/timm/layers/pos_embed_rel.py @@ -3,15 +3,19 @@ Hacked together by / Copyright 2022 Ross Wightman """ import math +import os from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from .interpolate import RegularGridInterpolator from .mlp import Mlp from .weight_init import trunc_normal_ +_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0 + def gen_relative_position_index( q_size: Tuple[int, int], @@ -20,51 +24,219 @@ def gen_relative_position_index( ) -> torch.Tensor: # Adapted with significant modifications from Swin / BeiT codebases # get pair-wise relative position index for each token inside the window - if k_size is None: - coords = torch.stack( - torch.meshgrid([ - torch.arange(q_size[0]), - torch.arange(q_size[1]) - ]) - ).flatten(1) # 2, Wh, Ww - relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 - num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3 - else: - # FIXME different q vs k sizes is a WIP, need to better offset the two grids? - q_coords = torch.stack( - torch.meshgrid([ - torch.arange(q_size[0]), - torch.arange(q_size[1]) - ]) - ).flatten(1) # 2, Wh, Ww - k_coords = torch.stack( - torch.meshgrid([ - torch.arange(k_size[0]), - torch.arange(k_size[1]) - ]) - ).flatten(1) - relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 - # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 - # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 - # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 - # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw - num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3 - - _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0) + assert k_size is None, 'Different q & k sizes not currently supported' # FIXME + + coords = torch.stack( + torch.meshgrid([ + torch.arange(q_size[0]), + torch.arange(q_size[1]) + ]) + ).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += q_size[1] - 1 + relative_coords[:, :, 0] *= 2 * q_size[1] - 1 + num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + + # else: + # # FIXME different q vs k sizes is a WIP, need to better offset the two grids? + # q_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(q_size[0]), + # torch.arange(q_size[1]) + # ]) + # ).flatten(1) # 2, Wh, Ww + # k_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(k_size[0]), + # torch.arange(k_size[1]) + # ]) + # ).flatten(1) + # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 + # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 + # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 + # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw + # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3 + + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww if class_token: # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias # NOTE not intended or tested with MLP log-coords relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index[0, 0:] = num_relative_distance + relative_position_index[0:, 0] = num_relative_distance + 1 + relative_position_index[0, 0] = num_relative_distance + 2 return relative_position_index.contiguous() +def resize_rel_pos_bias_table_simple( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + _, dst_h, dst_w = new_bias_shape + num_attn_heads, src_h, src_w = rel_pos_bias.shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + if src_h != dst_h or src_w != dst_w: + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.unsqueeze(0), + size=dst_size, + mode="bicubic", + align_corners=False, + ).squeeze(0) + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])), + size=dst_size, + mode="bicubic", + align_corners=False, + ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1) + + if extra_tokens is not None: + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + +def resize_rel_pos_bias_table( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + """ Resize relative position bias table using more advanced interpolation. + + Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc). + + https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351 + + Args: + rel_pos_bias: + new_window_size: + new_bias_shape: + + Returns: + + """ + if _USE_SCIPY: + from scipy import interpolate + + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + num_extra_tokens = 0 + _, dst_h, dst_w = new_bias_shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + num_attn_heads, src_h, src_w = rel_pos_bias.shape + src_size = (src_h, src_w) + has_flat_shape = False + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) + has_flat_shape = True + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1])) + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + def _calc(src, dst): + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src // 2) + if gp > dst // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src // 2): + dis.append(cur) + cur += q ** (i + 1) + r_ids = [-_ for _ in reversed(dis)] + return r_ids + [0] + dis + + y = _calc(src_size[0], dst_size[0]) + x = _calc(src_size[1], dst_size[1]) + yx = [torch.tensor(y), torch.tensor(x)] + # print("Original positions = %s" % str(x)) + + ty = dst_size[0] // 2.0 + tx = dst_size[1] // 2.0 + dy = torch.arange(-ty, ty + 0.1, 1.0) + dx = torch.arange(-tx, tx + 0.1, 1.0) + dyx = torch.meshgrid([dy, dx]) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + for i in range(num_attn_heads): + if has_flat_shape: + z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float() + else: + z = rel_pos_bias[i, :, :].float() + + if _USE_SCIPY: + # Original beit code uses scipy w/ cubic interpolation + f = interpolate.interp2d(x, y, z.numpy(), kind='cubic') + r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device) + else: + # Without scipy dependency, I've found a reasonably simple impl + # that supports uneven spaced interpolation pts with 'linear' interp. + # Results are comparable to scipy for model accuracy in most cases. + f = RegularGridInterpolator(yx, z) + r = f(dyx).contiguous().to(rel_pos_bias.device) + + if has_flat_shape: + r = r.view(-1, 1) + all_rel_pos_bias.append(r) + + if has_flat_shape: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + else: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0) + + if extra_tokens is not None: + assert has_flat_shape + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + class RelPosBias(nn.Module): """ Relative Position Bias Adapted from Swin-V1 relative position bias impl, modularized. diff --git a/timm/models/beit.py b/timm/models/beit.py index 3472c0dca2..3863198f12 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -48,6 +48,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn +from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table + from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -115,7 +117,7 @@ def __init__( self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) + self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False) else: self.window_size = None self.relative_position_bias_table = None @@ -504,11 +506,46 @@ def _cfg(url='', **kwargs): }) -def _beit_checkpoint_filter_fn(state_dict, model): - if 'module' in state_dict: - # beit v2 didn't strip module - state_dict = state_dict['module'] - return checkpoint_filter_fn(state_dict, model) +def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('module', state_dict) + # beit v2 didn't strip module + + out_dict = {} + for k, v in state_dict.items(): + if 'relative_position_index' in k: + continue + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 1 + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + out_dict[k] = v + return out_dict def _create_beit(variant, pretrained=False, **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e3cf7adde4..12709f5818 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -48,7 +48,7 @@ from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert -from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn +from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq @@ -186,9 +186,9 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): attn_bias = shared_rel_pos x = torch.nn.functional.scaled_dot_product_attention( - q.transpose(-1, -2), - k.transpose(-1, -2), - v.transpose(-1, -2), + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), attn_mask=attn_bias, dropout_p=self.attn_drop.p, ).transpose(-1, -2).reshape(B, -1, H, W) @@ -1790,6 +1790,15 @@ def checkpoint_filter_fn(state_dict, model: nn.Module): model_state_dict = model.state_dict() out_dict = {} for k, v in state_dict.items(): + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): # adapt between conv2d / linear layers assert v.ndim in (2, 4) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 4b9c21c0b0..a96c69548c 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ - _assert, use_fused_attn + _assert, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply @@ -625,7 +625,6 @@ def checkpoint_filter_fn(state_dict, model): if 'head.fc.weight' in state_dict: old_weights = False import re - current_state_dict = model.state_dict() out_dict = {} state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) @@ -635,15 +634,12 @@ def checkpoint_filter_fn(state_dict, model): if k.endswith('relative_position_bias_table'): m = model.get_submodule(k[:-29]) - bias_size = tuple([2 * x -1 for x in m.window_size]) - old_len = int(len(v) ** 0.5) # we have to assume pretrained weight is square right now if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: - new_pos_bias = torch.nn.functional.interpolate( - v.transpose(1, 0).reshape(1, -1, old_len, old_len), - size=bias_size, - mode="bicubic", + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, ) - v = new_pos_bias.reshape(-1, bias_size[0] * bias_size[1]).transpose(0, 1) if old_weights: k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 1020217ca8..eca2ae7939 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -570,7 +570,7 @@ def _init_weights(self, m): def no_weight_decay(self): nod = set() for n, m in self.named_modules(): - if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): + if any([kw in n for kw in ("cpb_mlp", "logit_scale")]): nod.add(n) return nod