diff --git a/examples/benchmarks/bert/src/bert_layers.py b/examples/benchmarks/bert/src/bert_layers.py
index 59cd9ff7..8050394f 100644
--- a/examples/benchmarks/bert/src/bert_layers.py
+++ b/examples/benchmarks/bert/src/bert_layers.py
@@ -39,6 +39,7 @@
 import os
 import sys
 import warnings
+from functools import lru_cache
 from typing import List, Optional, Tuple, Union
 
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
@@ -78,11 +79,12 @@
 logger = logging.getLogger(__name__)
 
 
-if flash_attn_qkvpacked_func is not None:
+
+@lru_cache
+def _get_half_dtype() -> torch.dtype:
     if torch.cuda.is_bf16_supported():
-        HALF_DTYPE = torch.bfloat16
-    else:
-        HALF_DTYPE = torch.float16
+        return torch.bfloat16
+    return torch.float16
 
 
 class BertEmbeddings(nn.Module):
@@ -262,9 +264,9 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 # If FA2 is supported, bfloat16 must be supported
                 # as of FA2 2.4.2. (Turing GPUs not supported)
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(HALF_DTYPE)
+                qkv = qkv.to(torch.bfloat16)
                 bias_dtype = bias.dtype
-                bias = bias.to(HALF_DTYPE)
+                bias = bias.to(torch.bfloat16)
                 attention = flash_attn_qkvpacked_func(
                     qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
@@ -277,12 +279,13 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 # Triton implementation only supports 0 attention dropout
                 convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
                 if convert_dtype:
+                    half = _get_half_dtype()
                     # Triton implementation only supports fp16 and bf16
                     orig_dtype = qkv.dtype
-                    qkv = qkv.to(HALF_DTYPE)
+                    qkv = qkv.to(half)
                     bias_dtype = bias.dtype
-                    bias = bias.to(HALF_DTYPE)
+                    bias = bias.to(half)
                     attention = flash_attn_qkvpacked_func(qkv, bias)
                     attention = attention.to(orig_dtype)
                     bias = bias.to(bias_dtype)
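
For reference, a minimal standalone sketch (not part of the patch) of the pattern this change introduces: the bf16-vs-fp16 choice moves out of module import time and into a cached helper, so torch.cuda.is_bf16_supported() is only queried the first time a half-precision dtype is actually needed. The torch.cuda.is_available() guard is an assumption added here so the sketch also runs on CPU-only machines; the patch itself does not include it.

from functools import lru_cache

import torch


@lru_cache
def _get_half_dtype() -> torch.dtype:
    # Evaluated on the first call and memoized, rather than at module import time.
    # The is_available() check is an addition for this sketch only.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16


if __name__ == '__main__':
    x = torch.randn(2, 4)
    # torch.bfloat16 where bf16 is supported, otherwise torch.float16
    print(x.to(_get_half_dtype()).dtype)

Only the Triton path uses the helper and so keeps the fp16 fallback on GPUs without bf16; the FA2 path hard-codes bfloat16, since, per the patch comments, FA2 2.4.2 no longer supports Turing GPUs and therefore implies bf16 support.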