Add more advanced interpolation method from BEiT and support non-square window & image size adaptation for

* beit/beit-v2
* maxxvit/coatnet
* swin transformer
And non-square windows for swin-v2
rwightman committed Aug 8, 2023
1 parent 1dab536 commit c153cd4
Showing 7 changed files with 340 additions and 57 deletions.
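
A hedged usage sketch (not part of this commit's diff): with this support in place, pretrained relative position bias tables can be adapted when a model is instantiated at a different image size. The model name and target size below are illustrative.

import timm

# Create a BEiT model at a non-default image size; the pretrained relative
# position bias tables are resized to match the new token grid during weight loading.
model = timm.create_model(
    'beit_base_patch16_224',  # illustrative model name
    pretrained=True,
    img_size=384,
)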
3 changes: 2 additions & 1 deletion timm/layers/__init__.py
@@ -37,7 +37,8 @@
from .patch_embed import PatchEmbed, PatchEmbedWithSize, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed, resample_abs_pos_embed_nhwc
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords, \
resize_rel_pos_bias_table, resize_rel_pos_bias_table_simple
from .pos_embed_sincos import pixel_freq_bands, freq_bands, build_sincos2d_pos_embed, build_fourier_pos_embed, \
build_rotary_pos_embed, apply_rot_embed, apply_rot_embed_cat, apply_rot_embed_list, apply_keep_indices_nlc, \
FourierEmbed, RotaryEmbedding, RotaryEmbeddingCat
68 changes: 68 additions & 0 deletions timm/layers/interpolate.py
@@ -0,0 +1,68 @@
""" Interpolation helpers for timm layers
RegularGridInterpolator from https://github.com/sbarratt/torch_interpolations
Copyright Shane Barratt, Apache 2.0 license
"""
import torch
from itertools import product


class RegularGridInterpolator:
""" Interpolate data defined on a rectilinear grid with even or uneven spacing.
Produces similar results to scipy RegularGridInterpolator or interp2d
in 'linear' mode.
Taken from https://github.com/sbarratt/torch_interpolations
"""

def __init__(self, points, values):
self.points = points
self.values = values

assert isinstance(self.points, tuple) or isinstance(self.points, list)
assert isinstance(self.values, torch.Tensor)

self.ms = list(self.values.shape)
self.n = len(self.points)

assert len(self.ms) == self.n

for i, p in enumerate(self.points):
assert isinstance(p, torch.Tensor)
assert p.shape[0] == self.values.shape[i]

def __call__(self, points_to_interp):
assert self.points is not None
assert self.values is not None

assert len(points_to_interp) == len(self.points)
K = points_to_interp[0].shape[0]
for x in points_to_interp:
assert x.shape[0] == K

idxs = []
dists = []
overalls = []
for p, x in zip(self.points, points_to_interp):
idx_right = torch.bucketize(x, p)
idx_right[idx_right >= p.shape[0]] = p.shape[0] - 1
idx_left = (idx_right - 1).clamp(0, p.shape[0] - 1)
dist_left = x - p[idx_left]
dist_right = p[idx_right] - x
dist_left[dist_left < 0] = 0.
dist_right[dist_right < 0] = 0.
both_zero = (dist_left == 0) & (dist_right == 0)
dist_left[both_zero] = dist_right[both_zero] = 1.

idxs.append((idx_left, idx_right))
dists.append((dist_left, dist_right))
overalls.append(dist_left + dist_right)

numerator = 0.
for indexer in product([0, 1], repeat=self.n):
as_s = [idx[onoff] for onoff, idx in zip(indexer, idxs)]
bs_s = [dist[1 - onoff] for onoff, dist in zip(indexer, dists)]
numerator += self.values[as_s] * \
torch.prod(torch.stack(bs_s), dim=0)
denominator = torch.prod(torch.stack(overalls), dim=0)
return numerator / denominator
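
A minimal usage sketch (not part of the commit), assuming the timm.layers.interpolate module path added above: linear interpolation of a small 2D table at query points on an unevenly spaced grid.

import torch
from timm.layers.interpolate import RegularGridInterpolator

# values defined on a 3x4 grid; the first axis is unevenly spaced
ys = torch.tensor([0., 1., 3.])
xs = torch.tensor([0., 1., 2., 3.])
values = ys[:, None] * 10 + xs[None, :]  # shape (3, 4), value = 10*y + x

interp = RegularGridInterpolator((ys, xs), values)

# query two points, (y=0.5, x=0.5) and (y=2.0, x=1.5); one tensor per axis
out = interp((torch.tensor([0.5, 2.0]), torch.tensor([0.5, 1.5])))
print(out)  # tensor([ 5.5000, 21.5000])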
244 changes: 208 additions & 36 deletions timm/layers/pos_embed_rel.py
@@ -3,15 +3,19 @@
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .interpolate import RegularGridInterpolator
from .mlp import Mlp
from .weight_init import trunc_normal_

_USE_SCIPY = int(os.environ.get('TIMM_USE_SCIPY_INTERP', 0)) > 0


def gen_relative_position_index(
q_size: Tuple[int, int],
@@ -20,51 +24,219 @@ def gen_relative_position_index(
) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
if k_size is None:
coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1) + 3
else:
# FIXME different q vs k sizes is a WIP, need to better offset the two grids?
q_coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
k_coords = torch.stack(
torch.meshgrid([
torch.arange(k_size[0]),
torch.arange(k_size[1])
])
).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + q_size[1] - 1) + 3

_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
assert k_size is None, 'Different q & k sizes not currently supported' # FIXME

coords = torch.stack(
torch.meshgrid([
torch.arange(q_size[0]),
torch.arange(q_size[1])
])
).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += q_size[1] - 1
relative_coords[:, :, 0] *= 2 * q_size[1] - 1
num_relative_distance = (2 * q_size[0] - 1) * (2 * q_size[1] - 1)

# else:
# # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
# q_coords = torch.stack(
# torch.meshgrid([
# torch.arange(q_size[0]),
# torch.arange(q_size[1])
# ])
# ).flatten(1) # 2, Wh, Ww
# k_coords = torch.stack(
# torch.meshgrid([
# torch.arange(k_size[0]),
# torch.arange(k_size[1])
# ])
# ).flatten(1)
# relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
# relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
# relative_coords[:, :, 0] += max(q_size[0], k_size[0]) - 1 # shift to start from 0
# relative_coords[:, :, 1] += max(q_size[1], k_size[1]) - 1
# relative_coords[:, :, 0] *= k_size[1] + q_size[1] - 1
# relative_position_index = relative_coords.sum(-1) # Qh*Qw, Kh*Kw
# num_relative_distance = (q_size[0] + k_size[0] - 1) * (q_size[1] + k_size[1] - 1) + 3

relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww

if class_token:
# handle cls to token & token to cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
relative_position_index[0, 0:] = num_relative_distance
relative_position_index[0:, 0] = num_relative_distance + 1
relative_position_index[0, 0] = num_relative_distance + 2

return relative_position_index.contiguous()
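
A hedged sketch (not from the diff) of how the index table is typically paired with a learnable bias table, following the Swin/BEiT pattern used by the RelPos* modules in this file; sizes below are illustrative.

import torch
import torch.nn as nn
from timm.layers import gen_relative_position_index

window_size = (7, 7)
num_heads = 4
index = gen_relative_position_index(window_size, class_token=True)  # (7*7 + 1, 7*7 + 1)

# (2*Wh-1)*(2*Ww-1) in-window offsets plus 3 entries for cls<->token and cls<->cls
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))

# gather per-pair biases and reshape to (num_heads, Q, K) for adding to attention logits
bias = bias_table[index.view(-1)].view(*index.shape, num_heads).permute(2, 0, 1)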


def resize_rel_pos_bias_table_simple(
rel_pos_bias,
new_window_size: Tuple[int, int],
new_bias_shape: Tuple[int, ...],
):
dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
if rel_pos_bias.ndim == 3:
# TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
_, dst_h, dst_w = new_bias_shape
num_attn_heads, src_h, src_w = rel_pos_bias.shape
assert dst_h == dst_size[0] and dst_w == dst_size[1]
if src_h != dst_h or src_w != dst_w:
rel_pos_bias = torch.nn.functional.interpolate(
rel_pos_bias.unsqueeze(0),
size=dst_size,
mode="bicubic",
align_corners=False,
).squeeze(0)
else:
assert rel_pos_bias.ndim == 2
# (num_pos, num_heads) (aka flat) bias shape
dst_num_pos, _ = new_bias_shape
src_num_pos, num_attn_heads = rel_pos_bias.shape
num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
src_size = (src_size, src_size) # FIXME could support non-equal src if argument passed

if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
if num_extra_tokens:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
else:
extra_tokens = None

rel_pos_bias = torch.nn.functional.interpolate(
rel_pos_bias.transpose(1, 0).reshape((1, -1, src_size[0], src_size[1])),
size=dst_size,
mode="bicubic",
align_corners=False,
).view(-1, dst_num_pos - num_extra_tokens).transpose(0, 1)

if extra_tokens is not None:
rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)

return rel_pos_bias
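
A hedged usage sketch of the simple variant: resizing a flat BEiT-style (num_pos, num_heads) table trained with a 7x7 window (13*13 relative offsets plus 3 class-token entries) to a 12x12 window. Shapes follow the conventions read from the code above.

import torch
from timm.layers import resize_rel_pos_bias_table_simple

num_heads = 12
old_table = torch.randn(13 * 13 + 3, num_heads)  # source: (2*7-1)^2 offsets + 3 extra entries
new_table = resize_rel_pos_bias_table_simple(
    old_table,
    new_window_size=(12, 12),
    new_bias_shape=(23 * 23 + 3, num_heads),  # target: (2*12-1)^2 offsets + 3 extra entries
)
assert new_table.shape == (23 * 23 + 3, num_heads)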


def resize_rel_pos_bias_table(
rel_pos_bias,
new_window_size: Tuple[int, int],
new_bias_shape: Tuple[int, ...],
):
""" Resize relative position bias table using more advanced interpolation.
Modified from code in Microsoft Unilm (https://github.com/microsoft/unilm) repo (BeiT, BeiT-v2, etc).
https://github.com/microsoft/unilm/blob/5255d52de86dad642810f5849dd357769346c1d7/beit/run_class_finetuning.py#L351
Args:
rel_pos_bias: Relative position bias table to resize, either flat (num_pos, num_heads) or (num_heads, H, W).
new_window_size: Target window size (height, width).
new_bias_shape: Full target shape of the bias table, including any extra class-token entries.
Returns:
Resized relative position bias table.
"""
if _USE_SCIPY:
from scipy import interpolate

dst_size = (new_window_size[0] * 2 - 1, new_window_size[1] * 2 - 1)
if rel_pos_bias.ndim == 3:
# TF maxvit style (num_heads, H, W) bias shape, no extra tokens currently supported
num_extra_tokens = 0
_, dst_h, dst_w = new_bias_shape
assert dst_h == dst_size[0] and dst_w == dst_size[1]
num_attn_heads, src_h, src_w = rel_pos_bias.shape
src_size = (src_h, src_w)
has_flat_shape = False
else:
assert rel_pos_bias.ndim == 2
# (num_pos, num_heads) (aka flat) bias shape
dst_num_pos, _ = new_bias_shape
src_num_pos, num_attn_heads = rel_pos_bias.shape
num_extra_tokens = dst_num_pos - (dst_size[0] * dst_size[1])
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
src_size = (src_size, src_size)
has_flat_shape = True

if src_size[0] != dst_size[0] or src_size[1] != dst_size[1]:
# print("Interpolating position from %dx%d to %dx%d" % (src_size[0], src_size[1], dst_size[0], dst_size[1]))
if num_extra_tokens:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
else:
extra_tokens = None

def geometric_progression(a, r, n):
return a * (1.0 - r ** n) / (1.0 - r)

def _calc(src, dst):
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src // 2)
if gp > dst // 2:
right = q
else:
left = q

dis = []
cur = 1
for i in range(src // 2):
dis.append(cur)
cur += q ** (i + 1)
r_ids = [-_ for _ in reversed(dis)]
return r_ids + [0] + dis

y = _calc(src_size[0], dst_size[0])
x = _calc(src_size[1], dst_size[1])
yx = [torch.tensor(y), torch.tensor(x)]
# print("Original positions = %s" % str(x))

ty = dst_size[0] // 2.0
tx = dst_size[1] // 2.0
dy = torch.arange(-ty, ty + 0.1, 1.0)
dx = torch.arange(-tx, tx + 0.1, 1.0)
dyx = torch.meshgrid([dy, dx])
# print("Target positions = %s" % str(dx))

all_rel_pos_bias = []
for i in range(num_attn_heads):
if has_flat_shape:
z = rel_pos_bias[:, i].view(src_size[0], src_size[1]).float()
else:
z = rel_pos_bias[i, :, :].float()

if _USE_SCIPY:
# Original beit code uses scipy w/ cubic interpolation
f = interpolate.interp2d(x, y, z.numpy(), kind='cubic')
r = torch.Tensor(f(dx, dy)).contiguous().to(rel_pos_bias.device)
else:
# Without scipy dependency, I've found a reasonably simple impl
# that supports unevenly spaced interpolation points with 'linear' interp.
# Results are comparable to scipy for model accuracy in most cases.
f = RegularGridInterpolator(yx, z)
r = f(dyx).contiguous().to(rel_pos_bias.device)

if has_flat_shape:
r = r.view(-1, 1)
all_rel_pos_bias.append(r)

if has_flat_shape:
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
else:
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=0)

if extra_tokens is not None:
assert has_flat_shape
rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)

return rel_pos_bias
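
A hedged sketch of the advanced variant, which shares the call signature of the simple one but interpolates on the geometric-progression grid built above; here a flat table is resized from a square 7x7 window to a non-square 14x10 window. Shapes are illustrative.

import torch
from timm.layers import resize_rel_pos_bias_table

num_heads = 12
old_table = torch.randn(13 * 13 + 3, num_heads)  # trained with a 7x7 window, 3 extra entries
new_table = resize_rel_pos_bias_table(
    old_table,
    new_window_size=(14, 10),  # non-square target window
    new_bias_shape=(27 * 19 + 3, num_heads),  # (2*14-1) * (2*10-1) offsets + 3 extra entries
)
assert new_table.shape == (27 * 19 + 3, num_heads)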


class RelPosBias(nn.Module):
""" Relative Position Bias
Adapted from Swin-V1 relative position bias impl, modularized.
(Diffs for the remaining changed files were not loaded.)
