diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py
index e5d4a96b2d..68fc6c75c6 100644
--- a/timm/models/efficientvit_mit.py
+++ b/timm/models/efficientvit_mit.py
@@ -14,10 +14,11 @@
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from ._registry import register_model, generate_default_cfgs
+from timm.layers import SelectAdaptivePool2d, create_conv2d
 from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
-from timm.layers import SelectAdaptivePool2d, create_conv2d
+from ._registry import register_model, generate_default_cfgs
 
 
 def val2list(x: list or tuple or any, repeat_time=1):
@@ -233,6 +234,14 @@ def __init__(
             act_layer=act_layer[1],
         )
 
+    def _attn(self, q, k, v):
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        kv = k.transpose(-1, -2) @ v
+        out = q @ kv
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        return out.to(dtype)
+
     def forward(self, x):
         B, _, H, W = x.shape
 
@@ -243,20 +252,18 @@ def forward(self, x):
             multi_scale_qkv.append(op(qkv))
         multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
         multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
-        q, k, v = multi_scale_qkv.tensor_split(3, dim=-1)
+        q, k, v = multi_scale_qkv.chunk(3, dim=-1)
 
         # lightweight global attention
         q = self.kernel_func(q)
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)
 
-        dtype = v.dtype
-        q, k, v = q.float(), k.float(), v.float()
-        with torch.amp.autocast(device_type=v.device.type, enabled=False):
-            kv = k.transpose(-1, -2) @ v
-            out = q @ kv
-            out = out[..., :-1] / (out[..., -1:] + self.eps)
-        out = out.to(dtype)
+        if not torch.jit.is_scripting():
+            with torch.amp.autocast(device_type=v.device.type, enabled=False):
+                out = self._attn(q, k, v)
+        else:
+            out = self._attn(q, k, v)
 
         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
@@ -264,6 +271,9 @@ def forward(self, x):
         return out
 
 
+register_notrace_module(LiteMSA)
+
+
 class EfficientVitBlock(nn.Module):
     def __init__(
             self,
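
A note on the new `if not torch.jit.is_scripting()` split: TorchScript has only limited support for autocast context managers, so the scripted path calls `_attn()` directly while the eager path keeps the `torch.amp.autocast(..., enabled=False)` wrapper; the fp32 upcast itself now lives inside `_attn()`, so both paths compute in float32 regardless. Below is a minimal standalone sketch of the same pattern, using a hypothetical `Demo` module (not timm code):

import torch
import torch.nn as nn


class Demo(nn.Module):
    # hypothetical module, mirroring the structure the patch gives LiteMSA.forward
    def _attn(self, x: torch.Tensor) -> torch.Tensor:
        # upcast to fp32 for stable math, restore the input dtype afterwards
        dtype = x.dtype
        return x.float().softmax(dim=-1).to(dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.jit.is_scripting():
            # eager path: pin the attention math to fp32 even under autocast
            with torch.amp.autocast(device_type=x.device.type, enabled=False):
                out = self._attn(x)
        else:
            # scripted path: no context manager, _attn handles dtype itself
            out = self._attn(x)
        return out


scripted = torch.jit.script(Demo())  # compiles with the guard in place
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Factoring the math into `_attn()` keeps each guarded branch down to a single call, so the two paths cannot drift apart.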
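
For reference, the `F.pad(v, (0, 1), value=1.)` plus `out[..., :-1] / (out[..., -1:] + self.eps)` pair that moved into `_attn()` is the usual linear-attention normalizer trick: appending a ones channel to `v` lets a single `k^T @ v` matmul accumulate both the attention numerator and its denominator. A self-contained sketch (illustrative shapes and eps value, not timm code) that checks the padded form against an explicit normalizer:

import torch
import torch.nn.functional as F

B, N, d = 2, 16, 8  # batch, tokens, head dim (illustrative sizes)
eps = 1e-15         # small stabilizer; the model reads this from self.eps

# ReLU plays the role of LiteMSA's kernel_func here
q = F.relu(torch.randn(B, N, d))
k = F.relu(torch.randn(B, N, d))
v = torch.randn(B, N, d)

# append a constant-one channel to v
v_pad = F.pad(v, (0, 1), mode="constant", value=1.)  # (B, N, d + 1)

# one matmul accumulates sum_j k_j v_j^T and, in the last column, sum_j k_j
kv = k.transpose(-1, -2) @ v_pad                     # (B, d, d + 1)
out = q @ kv                                         # (B, N, d + 1)
out = out[..., :-1] / (out[..., -1:] + eps)          # numerator / normalizer

# reference: same attention with the normalizer computed explicitly
num = q @ (k.transpose(-1, -2) @ v)
den = q @ k.sum(dim=1, keepdim=True).transpose(-1, -2) + eps
print(torch.allclose(out, num / den, atol=1e-5))     # True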