
revert back to old HALF selection
Skylion007 committed Jan 4, 2024
1 parent eb93f3b commit 58d9ffe
Showing 1 changed file with 11 additions and 8 deletions.
examples/benchmarks/bert/src/bert_layers.py (19 changes: 11 additions & 8 deletions)
@@ -39,6 +39,7 @@
 import os
 import sys
 import warnings
+from functools import lru_cache
 from typing import List, Optional, Tuple, Union

 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
@@ -78,11 +79,12 @@

 logger = logging.getLogger(__name__)

-if flash_attn_qkvpacked_func is not None:
+
+@lru_cache
+def _get_half_dtype() -> torch.dtype:
     if torch.cuda.is_bf16_supported():
-        HALF_DTYPE = torch.bfloat16
-    else:
-        HALF_DTYPE = torch.float16
+        return torch.bfloat16
+    return torch.float16


 class BertEmbeddings(nn.Module):
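For context, here is the dtype-selection pattern the new helper follows, as a minimal standalone sketch (it mirrors the names in the hunk above but is not the repo's code, and it assumes a CUDA build of PyTorch). functools.lru_cache memoizes the result, so the bf16 capability query runs at most once per process:

from functools import lru_cache

import torch


@lru_cache
def _get_half_dtype() -> torch.dtype:
    # torch.cuda.is_bf16_supported() is only evaluated on the first call;
    # every later call returns the cached dtype.
    if torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


if __name__ == '__main__':
    print(f'fused attention kernels will cast inputs to {_get_half_dtype()}')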
@@ -262,9 +264,9 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 # If FA2 is supported, bfloat16 must be supported
                 # as of FA2 2.4.2. (Turing GPUs not supported)
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(HALF_DTYPE)
+                qkv = qkv.to(torch.bfloat16)
                 bias_dtype = bias.dtype
-                bias = bias.to(HALF_DTYPE)
+                bias = bias.to(torch.bfloat16)

                 attention = flash_attn_qkvpacked_func(
                     qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
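The comment in this hunk carries the reasoning for hard-coding torch.bfloat16 on the FA2 path: flash-attn 2.4.2 and later does not support Turing GPUs, so any GPU that can run the kernel also supports bfloat16. A hedged, illustrative way to state that assumption as an explicit guard (not code from the repo; flash_attn_qkvpacked_func is the flash-attn 2.x entry point):

import torch

try:
    from flash_attn import flash_attn_qkvpacked_func
except ImportError:
    flash_attn_qkvpacked_func = None

if flash_attn_qkvpacked_func is not None:
    # Illustrative guard only: if the FA2 kernel is importable and usable,
    # the diff assumes the GPU is bf16-capable, so an unconditional
    # .to(torch.bfloat16) cast before the kernel is safe.
    assert torch.cuda.is_bf16_supported(), (
        'flash-attn >= 2.4.2 targets Ampere or newer GPUs, which support bfloat16')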
@@ -277,12 +279,13 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 # Triton implementation only supports 0 attention dropout
                 convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
                 if convert_dtype:
+                    half = _get_half_dtype()

                     # Triton implementation only supports fp16 and bf16
                     orig_dtype = qkv.dtype
-                    qkv = qkv.to(HALF_DTYPE)
+                    qkv = qkv.to(half)
                     bias_dtype = bias.dtype
-                    bias = bias.to(HALF_DTYPE)
+                    bias = bias.to(half)
                     attention = flash_attn_qkvpacked_func(qkv, bias)
                     attention = attention.to(orig_dtype)
                     bias = bias.to(bias_dtype)
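Both branches above share the same cast, compute, restore shape. A compact sketch of that pattern, with a no-op stand-in for the fused kernel so it runs anywhere (illustrative names, not the repo's code):

import torch


def _fake_kernel(qkv: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Stand-in for flash_attn_qkvpacked_func; real fused kernels only accept
    # fp16/bf16 inputs, which is what forces the casts in the first place.
    assert qkv.dtype in (torch.float16, torch.bfloat16)
    assert bias.dtype == qkv.dtype
    return qkv


def attention_in_half(qkv: torch.Tensor, bias: torch.Tensor,
                      half: torch.dtype) -> torch.Tensor:
    # Cast inputs down, run the kernel, cast the result back to the caller's dtype.
    orig_dtype = qkv.dtype
    out = _fake_kernel(qkv.to(half), bias.to(half))
    return out.to(orig_dtype)


# The FA2 branch above effectively uses half=torch.bfloat16; the Triton
# fallback passes the value returned by _get_half_dtype() instead.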
