diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 00cb36d215..39fa7162ac 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -213,7 +213,8 @@ def flash_attn_fn(
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
-        raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.2')
+        raise RuntimeError(
+            'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
 
     check_valid_inputs(query, key, value)