From 26ea51f0978013e9419d10e1b8affa26e970e4ec Mon Sep 17 00:00:00 2001
From: Yassine
Date: Wed, 4 Oct 2023 14:30:19 -0700
Subject: [PATCH] fix all SDPA dropouts

---
 timm/models/beit.py                      |  2 +-
 timm/models/cait.py                      |  2 +-
 timm/models/eva.py                       |  2 +-
 timm/models/fastvit.py                   |  2 +-
 timm/models/maxxvit.py                   |  4 ++--
 timm/models/metaformer.py                |  2 +-
 timm/models/nest.py                      |  6 +++---
 timm/models/pvt_v2.py                    |  2 +-
 timm/models/swin_transformer.py          |  2 +-
 timm/models/twins.py                     |  4 ++--
 timm/models/visformer.py                 |  2 +-
 timm/models/vision_transformer.py        | 10 +++++-----
 timm/models/vision_transformer_relpos.py |  2 +-
 timm/models/vision_transformer_sam.py    |  2 +-
 14 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/timm/models/beit.py b/timm/models/beit.py
index 3863198f12..663dcc4bd4 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -155,7 +155,7 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/cait.py b/timm/models/cait.py
index 4bc7dafc53..40d56061d3 100644
--- a/timm/models/cait.py
+++ b/timm/models/cait.py
@@ -50,7 +50,7 @@ def forward(self, x):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 81bcce525d..68b315386c 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -126,7 +126,7 @@ def forward(
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py
index b156ade0ef..b3143ae58b 100644
--- a/timm/models/fastvit.py
+++ b/timm/models/fastvit.py
@@ -514,7 +514,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 12709f5818..6283443ce5 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -190,7 +190,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale
@@ -259,7 +259,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py
index 98a79f598b..7b026a2e43 100644
--- a/timm/models/metaformer.py
+++ b/timm/models/metaformer.py
@@ -198,7 +198,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/timm/models/nest.py b/timm/models/nest.py
index de57ec6e99..d1901cee21 100644
--- a/timm/models/nest.py
+++ b/timm/models/nest.py
@@ -59,14 +59,14 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.)
     def forward(self, x):
         """
         x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
-        """ 
+        """
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)
@@ -330,7 +330,7 @@ def __init__(
         # Hint: (img_size // patch_size) gives number of patches along edge of image.  sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
         self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
-        
+
         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size,
diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py
index 00379b158a..16302002eb 100644
--- a/timm/models/pvt_v2.py
+++ b/timm/models/pvt_v2.py
@@ -130,7 +130,7 @@ def forward(self, x, feat_size: List[int]):
             k, v = kv.unbind(0)
 
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 41b45afb69..34452c7cf1 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -164,7 +164,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/twins.py b/timm/models/twins.py
index b96a0234d0..3cd25fb433 100644
--- a/timm/models/twins.py
+++ b/timm/models/twins.py
@@ -75,7 +75,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -172,7 +172,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 9f5da60be5..953fc64d5e 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -95,7 +95,7 @@ def forward(self, x):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 10b9296b49..b82b9865f2 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -85,7 +85,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -285,7 +285,7 @@ def forward(self, x):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -1151,7 +1151,7 @@ def _cfg(url='', **kwargs):
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    
+
     # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
     'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
@@ -1471,7 +1471,7 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    
+
     'vit_huge_patch14_224_ijepa.in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
@@ -2080,7 +2080,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
     model_args = dict(
-        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, 
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index ea428587c1..2cd37cfe7e 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -71,7 +71,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index 53c49b071e..59b354fb3d 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -168,7 +168,7 @@ def forward(self, x):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
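
Reviewer note (not part of the patch): every hunk applies the same pattern. nn.Dropout
is a no-op in eval mode, but F.scaled_dot_product_attention applies whatever dropout_p
it is given regardless of module mode, so forwarding the raw self.attn_drop.p kept
dropping attention weights at inference time. Gating on self.training restores
deterministic eval behaviour. A minimal standalone sketch of the pattern follows
(hypothetical TinyAttention module, illustration only, not timm code):

import torch
import torch.nn.functional as F
from torch import nn


class TinyAttention(nn.Module):
    """Toy attention layer illustrating the training-gated dropout_p fix."""

    def __init__(self, dim=64, num_heads=4, attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)  # kept so the drop probability lives in .p, as in timm
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,  # the fix applied throughout the patch
        )
        return self.proj(x.transpose(1, 2).reshape(B, N, C))


if __name__ == '__main__':
    m = TinyAttention().eval()  # eval mode: dropout_p resolves to 0.
    with torch.no_grad():
        out1 = m(torch.ones(1, 8, 64))
        out2 = m(torch.ones(1, 8, 64))
    assert torch.allclose(out1, out2)  # deterministic at inference, unlike the pre-fix behaviour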