
Commit

fix all SDPA dropouts
YassineYousfi committed Oct 4, 2023
1 parent 056fca2 commit 26ea51f
Showing 14 changed files with 22 additions and 22 deletions.
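
Each file gets the same one-line change: F.scaled_dot_product_attention takes dropout_p as a plain float and applies dropout whenever that value is non-zero, regardless of the module's train/eval state, unlike an nn.Dropout layer, which disables itself in eval mode. Gating the rate on self.training restores the expected eval-time behavior. A minimal sketch of the pattern, using an illustrative FusedAttention module rather than any exact timm class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedAttention(nn.Module):
    # Illustrative module, not an exact timm class.
    def __init__(self, dim: int, num_heads: int = 8, attn_drop: float = 0.1, proj_drop: float = 0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)  # .p stays at attn_drop even in eval mode
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # SDPA has no notion of module mode, so the dropout rate must be zeroed
        # explicitly when the module is not training.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))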
timm/models/beit.py (2 changes: 1 addition & 1 deletion)

@@ -155,7 +155,7 @@ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/cait.py (2 changes: 1 addition & 1 deletion)

@@ -50,7 +50,7 @@ def forward(self, x):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/eva.py (2 changes: 1 addition & 1 deletion)

@@ -126,7 +126,7 @@ def forward(
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/fastvit.py (2 changes: 1 addition & 1 deletion)

@@ -514,7 +514,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/maxxvit.py (4 changes: 2 additions & 2 deletions)

@@ -190,7 +190,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale

@@ -259,7 +259,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/metaformer.py (2 changes: 1 addition & 1 deletion)

@@ -198,7 +198,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
timm/models/nest.py (6 changes: 3 additions & 3 deletions)

@@ -59,14 +59,14 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.)
     def forward(self, x):
         """
         x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
-        """
+        """
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)

@@ -330,7 +330,7 @@ def __init__(
         # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
         # number of blocks along edge of image
         self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))

         # Patch embedding
         self.patch_embed = PatchEmbed(
             img_size=img_size,
timm/models/pvt_v2.py (2 changes: 1 addition & 1 deletion)

@@ -130,7 +130,7 @@ def forward(self, x, feat_size: List[int]):
         k, v = kv.unbind(0)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
timm/models/swin_transformer.py (2 changes: 1 addition & 1 deletion)

@@ -164,7 +164,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/twins.py (4 changes: 2 additions & 2 deletions)

@@ -75,7 +75,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -172,7 +172,7 @@ def forward(self, x, size: Size_):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/visformer.py (2 changes: 1 addition & 1 deletion)

@@ -95,7 +95,7 @@ def forward(self, x):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale
timm/models/vision_transformer.py (10 changes: 5 additions & 5 deletions)

@@ -85,7 +85,7 @@ def forward(self, x):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -285,7 +285,7 @@ def forward(self, x):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -1151,7 +1151,7 @@ def _cfg(url='', **kwargs):
         url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
         hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

     # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
     'vit_small_patch14_dinov2.lvd142m': _cfg(
         url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',

@@ -1471,7 +1471,7 @@ def _cfg(url='', **kwargs):
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),

     'vit_huge_patch14_224_ijepa.in1k': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',

@@ -2080,7 +2080,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192

     model_args = dict(
-        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
     )
     model = _create_vision_transformer(
timm/models/vision_transformer_relpos.py (2 changes: 1 addition & 1 deletion)

@@ -71,7 +71,7 @@ def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
timm/models/vision_transformer_sam.py (2 changes: 1 addition & 1 deletion)

@@ -168,7 +168,7 @@ def forward(self, x):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
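
The effect of the gating can be checked with a quick determinism test; this reuses the illustrative FusedAttention sketch above and is only an assumed example, not part of the commit:

torch.manual_seed(0)
attn = FusedAttention(dim=64, num_heads=8, attn_drop=0.1)
attn.eval()  # eval mode should make attention dropout a no-op
x = torch.randn(2, 16, 64)
with torch.no_grad():
    out1 = attn(x)
    out2 = attn(x)
# Passes with dropout_p gated on self.training; with an unconditional
# dropout_p=0.1 the two forward passes would differ.
assert torch.allclose(out1, out2)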
