diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py
index caec5e696e..8ce3687ef2 100644
--- a/timm/layers/__init__.py
+++ b/timm/layers/__init__.py
@@ -37,7 +37,8 @@
 from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
-from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
+from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
+    resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
 from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
     build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
     FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
diff --git a/timm/layers/interpolate.py b/timm/layers/interpolate.py
new file mode 100644
index 0000000000..adba9342ec
--- /dev/null
+++ b/timm/layers/interpolate.py
@@ -0,0 +1,68 @@
+""" Interpolation helpers for timm layers
+
+RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
+Copyright Shane Barratt, Apache 2.0 license
+"""
+import torch
+from itertools import product
+
+
+class RegularGridInterpolator:
+    """ Interpolate data defined on a rectilinear grid with even or uneven spacing.
+    Produces similar results to scipy RegularGridInterpolator or interp2d
+    in 'linear' mode.
+
+    Taken from https://github.com/sbarratt/torch_interpolations
+    """
+
+    def __init__(self, points, values):
+        self.points = points
+        self.values = values
+
+        assert isinstance(self.points, tuple) or isinstance(self.points, list)
+        assert isinstance(self.values, torch.Tensor)
+
+        self.ms = list(self.values.shape)
+        self.n = len(self.points)
+
+        assert len(self.ms) == self.n
+
+        for i, p in enumerate(self.points):
+            assert isinstance(p, torch.Tensor)
+            assert p.shape[0] == self.values.shape[i]
+
+    def __call__(self, points_to_interp):
+        assert self.points is not None
+        assert self.values is not None
+
+        assert len(points_to_interp) == len(self.points)
+        K = points_to_interp[0].shape[0]
+        for x in points_to_interp:
+            assert x.shape[0] == K
+
+        idxs = []
+        dists = []
+        overalls = []
+        for p, x in zip(self.points, points_to_interp):
+            idx_right = torch.bucketize(x, p)
+            idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
+            idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
+            dist_left = x - p[idx_left]
+            dist_right = p[idx_right] - x
+            dist_left[dist_left < 0] = 0.
+            dist_right[dist_right < 0] = 0.
+            both_zero = (dist_left == 0) & (dist_right == 0)
+            dist_left[both_zero] = dist_right[both_zero] = 1.
+
+            idxs.append((idx_left, idx_right))
+            dists.append((dist_left, dist_right))
+            overalls.append(dist_left + dist_right)
+
+        numerator = 0.
+        for indexer in product([0, 1], repeat=self.n):
+            as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
+            bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
+            numerator += self.values[as_s] * \
+                torch.prod(torch.stack(bs_s), dim=0)
+        denominator = torch.prod(torch.stack(overalls), dim=0)
+        return numerator / denominator
diff --git a/timm/layers/pos_embed_rel.py b/timm/layers/pos_embed_rel.py
index 5cb3d0f4dd..dc4377d667 100644
--- a/timm/layers/pos_embed_rel.py
+++ b/timm/layers/pos_embed_rel.py
@@ -3,15 +3,19 @@
 Hacked together by / Copyright 2022 Ross Wightman
 """
 import math
+import os
 from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .interpolate import RegularGridInterpolator
 from .mlp import Mlp
 from .weight_init import trunc_normal_
 
+_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0
+
 
 def gen_relative_position_index(
     q_size: Tuple[int, int],
@@ -20,51 +24,219 @@ def gen_relative_position_index(
 ) -> torch.Tensor:
     # Adapted with significant modifications from Swin / BeiT codebases
     # get pair-wise relative position index for each token inside the window
-    if k_size is None:
-        coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
-                torch.arange(q_size[1])
-            ])
-        ).flatten(1)  # 2, Wh, Ww
-        relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
-        num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
-    else:
-        # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
-        q_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(q_size[0]),
-                torch.arange(q_size[1])
-            ])
-        ).flatten(1)  # 2, Wh, Ww
-        k_coords = torch.stack(
-            torch.meshgrid([
-                torch.arange(k_size[0]),
-                torch.arange(k_size[1])
-            ])
-        ).flatten(1)
-        relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-        relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
-        # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1  # shift to start from 0
-        # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
-        # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
-        # relative_position_index = relative_coords.sum(-1)  # Qh*Qw, Kh*Kw
-        num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3
-
-    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
+    assert k_size is None, 'Different q & k sizes not currently supported'  # FIXME
+
+    coords = torch.stack(
+        torch.meshgrid([
+            torch.arange(q_size[0]),
+            torch.arange(q_size[1])
+        ])
+    ).flatten(1)  # 2, Wh, Ww
+    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = relative_coords.permute(1, 2, 0)  # Qh*Qw, Kh*Kw, 2
+    relative_coords[:, :, 0] += q_size[0] - 1  # shift to start from 0
+    relative_coords[:, :, 1] += q_size[1] - 1
+    relative_coords[:, :, 0] *= 2 * q_size[1] - 1
+    num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)
+
+    # else:
+    #     # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
+ # q_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(q_size[0]), + # torch.arange(q_size[1]) + # ]) + # ).flatten(1) # 2, Wh, Ww + # k_coords = torch.stack( + # torch.meshgrid([ + # torch.arange(k_size[0]), + # torch.arange(k_size[1]) + # ]) + # ).flatten(1) + # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww + # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2 + # relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0 + # relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1 + # relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1 + # relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw + # num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3 + + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww if class_token: # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias # NOTE not intended or tested with MLP log-coords relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0]) - relative_position_index[0, 0:] = num_relative_distance - 3 - relative_position_index[0:, 0] = num_relative_distance - 2 - relative_position_index[0, 0] = num_relative_distance - 1 + relative_position_index[0, 0:] = num_relative_distance + relative_position_index[0:, 0] = num_relative_distance + 1 + relative_position_index[0, 0] = num_relative_distance + 2 return relative_position_index.contiguous() +def resize_rel_pos_bias_table_simple( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + _, dst_h, dst_w = new_bias_shape + num_attn_heads, src_h, src_w = rel_pos_bias.shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + if src_h != dst_h or src_w != dst_w: + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.unsqueeze(0), + size=dst_size, + mode="bicubic", + align_corners=False, + ).squeeze(0) + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + rel_pos_bias = torch.nn.functional.interpolate( + rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])), + size=dst_size, + mode="bicubic", + align_corners=False, + ).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1) + + if extra_tokens is not None: + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + +def resize_rel_pos_bias_table( + rel_pos_bias, + new_window_size: Tuple[int, int], + new_bias_shape: Tuple[int, ...], +): + """ Resize relative position bias table using more advanced interpolation. + + Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc). 
+ + https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351 + + Args: + rel_pos_bias: + new_window_size: + new_bias_shape: + + Returns: + + """ + if _USE_SCIPY: + from scipy import interpolate + + dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1) + if rel_pos_bias.ndim == 3: + # TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported + num_extra_tokens = 0 + _, dst_h, dst_w = new_bias_shape + assert dst_h == dst_size[0] and dst_w == dst_size[1] + num_attn_heads, src_h, src_w = rel_pos_bias.shape + src_size = (src_h, src_w) + has_flat_shape = False + else: + assert rel_pos_bias.ndim == 2 + # (num_pos, num_heads) (aka flat) bias shape + dst_num_pos, _ = new_bias_shape + src_num_pos, num_attn_heads = rel_pos_bias.shape + num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1]) + src_size = int((src_num_pos - num_extra_tokens) ** 0.5) + src_size = (src_size, src_size) + has_flat_shape = True + + if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]: + # print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1])) + if num_extra_tokens: + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + else: + extra_tokens = None + + def geometric_progression(a, r, n): + return a * (1.0 - r ** n) / (1.0 - r) + + def _calc(src, dst): + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src // 2) + if gp > dst // 2: + right = q + else: + left = q + + dis = [] + cur = 1 + for i in range(src // 2): + dis.append(cur) + cur += q ** (i + 1) + r_ids = [-_ for _ in reversed(dis)] + return r_ids + [0] + dis + + y = _calc(src_size[0], dst_size[0]) + x = _calc(src_size[1], dst_size[1]) + yx = [torch.tensor(y), torch.tensor(x)] + # print("Original positions = %s" % str(x)) + + ty = dst_size[0] // 2.0 + tx = dst_size[1] // 2.0 + dy = torch.arange(-ty, ty + 0.1, 1.0) + dx = torch.arange(-tx, tx + 0.1, 1.0) + dyx = torch.meshgrid([dy, dx]) + # print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + for i in range(num_attn_heads): + if has_flat_shape: + z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float() + else: + z = rel_pos_bias[i, :, :].float() + + if _USE_SCIPY: + # Original beit code uses scipy w/ cubic interpolation + f = interpolate.interp2d(x, y, z.numpy(), kind='cubic') + r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device) + else: + # Without scipy dependency, I've found a reasonably simple impl + # that supports uneven spaced interpolation pts with 'linear' interp. + # Results are comparable to scipy for model accuracy in most cases. + f = RegularGridInterpolator(yx, z) + r = f(dyx).contiguous().to(rel_pos_bias.device) + + if has_flat_shape: + r = r.view(-1, 1) + all_rel_pos_bias.append(r) + + if has_flat_shape: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + else: + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0) + + if extra_tokens is not None: + assert has_flat_shape + rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) + + return rel_pos_bias + + class RelPosBias(nn.Module): """ Relative Position Bias Adapted from Swin-V1 relative position bias impl, modularized. 
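A minimal usage sketch of the resize helpers added above, for readers skimming the diff. It assumes the flat BEiT-style (num_pos, num_heads) table layout with three extra cls-token entries; the sizes and random tensor below are illustrative only, not part of the change.

import torch
from timm.layers import resize_rel_pos_bias_table  # exported by the __init__.py change above

num_heads = 12
old_win, new_win = (7, 7), (12, 12)
# flat table: (2*Wh-1)*(2*Ww-1) relative positions plus 3 cls-token entries (BEiT layout)
old_num_pos = (2 * old_win[0] - 1) * (2 * old_win[1] - 1) + 3
new_num_pos = (2 * new_win[0] - 1) * (2 * new_win[1] - 1) + 3
old_table = torch.randn(old_num_pos, num_heads)  # stand-in for a pretrained table

new_table = resize_rel_pos_bias_table(
    old_table,
    new_window_size=new_win,
    new_bias_shape=(new_num_pos, num_heads),
)
assert new_table.shape == (new_num_pos, num_heads)

The same call works for the TF/MaxViT (num_heads, H, W) layout; resize_rel_pos_bias_table_simple takes identical arguments but uses plain bicubic interpolation instead of the geometric-progression grid.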
diff --git a/timm/models/beit.py b/timm/models/beit.py index 3472c0dca2..3863198f12 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -48,6 +48,8 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn +from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table + from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -115,7 +117,7 @@ def __init__( self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter( torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH - self.register_buffer("relative_position_index", gen_relative_position_index(window_size)) + self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False) else: self.window_size = None self.relative_position_bias_table = None @@ -504,11 +506,46 @@ def _cfg(url='', **kwargs): }) -def _beit_checkpoint_filter_fn(state_dict, model): - if 'module' in state_dict: - # beit v2 didn't strip module - state_dict = state_dict['module'] - return checkpoint_filter_fn(state_dict, model) +def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True): + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('module', state_dict) + # beit v2 didn't strip module + + out_dict = {} + for k, v in state_dict.items(): + if 'relative_position_index' in k: + continue + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 1 + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + out_dict[k] = v + return out_dict def _create_beit(variant, pretrained=False, **kwargs): diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e3cf7adde4..12709f5818 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -48,7 +48,7 @@ from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert -from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn +from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import named_apply, checkpoint_seq @@ -186,9 +186,9 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): attn_bias = 
shared_rel_pos x = torch.nn.functional.scaled_dot_product_attention( - q.transpose(-1, -2), - k.transpose(-1, -2), - v.transpose(-1, -2), + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), attn_mask=attn_bias, dropout_p=self.attn_drop.p, ).transpose(-1, -2).reshape(B, -1, H, W) @@ -1790,6 +1790,15 @@ def checkpoint_filter_fn(state_dict, model: nn.Module): model_state_dict = model.state_dict() out_dict = {} for k, v in state_dict.items(): + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): # adapt between conv2d / linear layers assert v.ndim in (2, 4) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index 2cca973673..a96c69548c 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -24,7 +24,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import PatchEmbed, Mlp, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ - _assert, use_fused_attn + _assert, use_fused_attn, resize_rel_pos_bias_table from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq, named_apply @@ -38,23 +38,28 @@ _int_or_tuple_2_t = Union[int, Tuple[int, int]] -def window_partition(x, window_size: int): +def window_partition( + x: torch.Tensor, + window_size: Tuple[int, int], +) -> torch.Tensor: """ + Partition into non-overlapping windows with padding if needed. Args: - x: (B, H, W, C) - window_size (int): window size + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. Returns: - windows: (num_windows*B, window_size, window_size, C) + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) return windows @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size: int, H: int, W: int): +def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) @@ -66,7 +71,7 @@ def window_reverse(windows, window_size: int, H: int, W: int): x: (B, H, W, C) """ C = windows.shape[-1] - x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C) + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) return x @@ -124,7 +129,7 @@ def __init__( self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) # get pair-wise relative position index for each token inside the window - self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w)) + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -218,14 +223,11 @@ def __init__( super().__init__() self.dim = dim self.input_resolution = input_resolution - self.window_size = window_size - self.shift_size = shift_size + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" self.norm1 = norm_layer(dim) self.attn = WindowAttention( @@ -237,8 +239,8 @@ def __init__( attn_drop=attn_drop, proj_drop=proj_drop, ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = Mlp( in_features=dim, @@ -246,66 +248,81 @@ def __init__( act_layer=act_layer, drop=proj_drop, ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - if self.shift_size > 0: + if any(self.shift_size): # calculate attention mask for SW-MSA H, W = self.input_resolution + H = math.ceil(H / self.window_size[0]) * self.window_size[0] + W = math.ceil(W / self.window_size[1]) * self.window_size[1] img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): for w in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): img_mask[:, h, w, :] = cnt cnt += 1 - mask_windows = window_partition(img_mask, self.window_size) # num_win, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - self.register_buffer("attn_mask", attn_mask) - def forward(self, x): - B, H, W, C = x.shape - _assert(H == self.input_resolution[0], "input feature has wrong size") - _assert(W == self.input_resolution[1], "input feature has wrong size") + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) - shortcut = x - x = self.norm1(x) + def _attn(self, x): + B, H, W, C = x.shape # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) else: shifted_x = x + # pad for resolution not divisible by window size + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + # partition windows - x_windows = window_partition(shifted_x, self.window_size) # num_win*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # num_win*B, window_size*window_size, C + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # num_win*B, window_size*window_size, C + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows - attn_windows = attn_windows.view(-1, self.window_size, 
self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + shifted_x = shifted_x[:, :H, :W, :].contiguous() # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) else: x = shifted_x + return x - # FFN - x = shortcut + self.drop_path(x) - + def forward(self, x): + B, H, W, C = x.shape + x = x + self.drop_path1(self._attn(self.norm1(x))) x = x.reshape(B, -1, C) - x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x + self.drop_path2(self.mlp(self.norm2(x))) x = x.reshape(B, H, W, C) return x @@ -385,6 +402,8 @@ def __init__( self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution self.depth = depth self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for w in window_size]) # patch merging layer if downsample: @@ -405,7 +424,7 @@ def __init__( num_heads=num_heads, head_dim=head_dim, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, + shift_size=0 if (i % 2 == 0) else shift_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, proj_drop=proj_drop, @@ -499,7 +518,11 @@ def __init__( # build layers head_dim = to_ntuple(self.num_layers)(head_dim) - window_size = to_ntuple(self.num_layers)(window_size) + if not isinstance(window_size, (list, tuple)): + window_size = to_ntuple(self.num_layers)(window_size) + elif len(window_size) == 2: + window_size = (window_size,) * self.num_layers + assert len(window_size) == self.num_layers mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] layers = [] @@ -598,15 +621,30 @@ def forward(self, x): def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" + old_weights = True if 'head.fc.weight' in state_dict: - return state_dict + old_weights = False import re out_dict = {} state_dict = state_dict.get('model', state_dict) state_dict = state_dict.get('state_dict', state_dict) for k, v in state_dict.items(): - k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) - k = k.replace('head.', 'head.fc.') + if any([n in k for n in ('relative_position_index', 'attn_mask')]): + continue # skip buffers that should not be persistent + + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + + if old_weights: + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v return out_dict diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index dba74a9a38..eca2ae7939 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -398,6 +398,8 @@ def __init__( self.depth = depth self.output_nchw = output_nchw self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for 
w in window_size])
 
         # patch merging / downsample layer
         if downsample:
@@ -413,7 +415,7 @@
                 input_resolution=self.output_resolution,
                 num_heads=num_heads,
                 window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                shift_size=0 if (i % 2 == 0) else shift_size,
                 mlp_ratio=mlp_ratio,
                 qkv_bias=qkv_bias,
                 proj_drop=proj_drop,
@@ -568,7 +570,7 @@ def _init_weights(self, m):
     def no_weight_decay(self):
         nod = set()
         for n, m in self.named_modules():
-            if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
+            if any([kw in n for kw in ("cpb_mlp", "logit_scale")]):
                 nod.add(n)
         return nod
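Taken together, the checkpoint filter changes above let pretrained BEiT / Swin / MaxViT weights load into models built at a different image or window size, with relative position bias tables resized on the fly. A hedged end-to-end sketch follows; the model name and sizes are illustrative and assume the entrypoint accepts img_size/window_size overrides via create_model.

import timm

# Build a Swin model at 384x384 with a 12x12 window while loading 224 / window-7 weights;
# checkpoint_filter_fn resizes each relative_position_bias_table to the new window shape.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=True,
    img_size=384,
    window_size=12,
)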