Commit 300f54a

Another efficientvit (mit) tweak, fix torchscript/fx conflict with autocast disable
rwightman committed Aug 20, 2023
1 parent dc18cda commit 300f54a
Showing 1 changed file with 20 additions and 10 deletions.
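
The conflict here: LiteMSA casts q, k, v to float32 and disables autocast around its linear-attention matmuls, presumably to keep the normalizing division numerically stable under mixed precision, but an unconditional torch.amp.autocast(..., enabled=False) context does not get along with torch.jit.script or FX symbolic tracing. The change factors the float32 math into an _attn helper, enters the autocast context only when not scripting, and registers the module as an FX leaf. Below is a minimal standalone sketch of that scripting guard, not the timm code itself; the Fp32Matmul name and its toy math are invented for illustration.

    import torch
    import torch.nn as nn


    class Fp32Matmul(nn.Module):
        # Hypothetical toy module standing in for LiteMSA's linear-attention core;
        # the math here is illustrative only.

        def _attn(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            dtype = v.dtype
            q, k, v = q.float(), k.float(), v.float()  # pin the matmuls to fp32
            out = q @ (k.transpose(-1, -2) @ v)
            return out.to(dtype)

        def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
            if not torch.jit.is_scripting():
                # Eager path: make sure an enclosing autocast region cannot drop
                # the fp32 matmuls back to half precision.
                with torch.amp.autocast(device_type=v.device.type, enabled=False):
                    return self._attn(q, k, v)
            # Scripting path: skip the context manager; the explicit casts in
            # _attn already keep the computation in fp32.
            return self._attn(q, k, v)


    m = torch.jit.script(Fp32Matmul())  # the guard keeps the autocast context out of the scripted path
    q = k = v = torch.randn(2, 16, 8)
    out = m(q, k, v)
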
timm/models/efficientvit_mit.py (30 changes: 20 additions & 10 deletions)
@@ -14,10 +14,11 @@
 import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
-from timm.layers import SelectAdaptivePool2d, create_conv2d
 from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
+from timm.layers import SelectAdaptivePool2d, create_conv2d
+from ._registry import register_model, generate_default_cfgs


 def val2list(x: list or tuple or any, repeat_time=1):
@@ -233,6 +234,14 @@ def __init__(
             act_layer=act_layer[1],
         )

+    def _attn(self, q, k, v):
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        kv = k.transpose(-1, -2) @ v
+        out = q @ kv
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        return out.to(dtype)
+
     def forward(self, x):
         B, _, H, W = x.shape

@@ -243,27 +252,28 @@ def forward(self, x):
             multi_scale_qkv.append(op(qkv))
         multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
         multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
-        q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
+        q, k, v = multi_scale_qkv.chunk(3, dim=-1)

         # lightweight global attention
         q = self.kernel_func(q)
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)

-        dtype = v.dtype
-        q, k, v = q.float(), k.float(), v.float()
-        with torch.amp.autocast(device_type=v.device.type, enabled=False):
-            kv = k.transpose(-1, -2) @ v
-            out = q @ kv
-            out = out[..., :-1] / (out[..., -1:] + self.eps)
-        out = out.to(dtype)
+        if not torch.jit.is_scripting():
+            with torch.amp.autocast(device_type=v.device.type, enabled=False):
+                out = self._attn(q, k, v)
+        else:
+            out = self._attn(q, k, v)

         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
         out = self.proj(out)
         return out


+register_notrace_module(LiteMSA)
+
+
 class EfficientVitBlock(nn.Module):
     def __init__(
             self,
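
The FX half of the fix is the new register_notrace_module(LiteMSA) call after the class: it adds LiteMSA to timm's set of leaf modules, so timm's FX-based feature extraction does not symbolically trace into LiteMSA.forward and never sees the autocast branch. A rough usage sketch follows; MyLeafBlock is made up for the example, and the absolute import path simply mirrors the relative ._features_fx import in the diff, so it may vary across timm versions.

    import torch.nn as nn

    from timm.models._features_fx import register_notrace_module


    class MyLeafBlock(nn.Module):
        # Hypothetical block whose forward has control flow that FX symbolic
        # tracing cannot follow (a data-dependent Python branch).
        def forward(self, x):
            if x.shape[-1] > 1:
                x = x * 2
            return x


    # Treat the class as a leaf: timm's FX feature extraction keeps MyLeafBlock
    # calls as single call_module nodes instead of tracing through forward().
    register_notrace_module(MyLeafBlock)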
