diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 8ff6387577..00cb36d215 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -213,7 +213,7 @@ def flash_attn_fn(
     try:
         from flash_attn import bert_padding, flash_attn_interface  # type: ignore # yapf: disable # isort: skip
     except:
-        raise RuntimeError('Please install flash-attn==1.0.9')
+        raise RuntimeError('Please install flash-attn==1.0.9 or flash-attn==2.3.2')
 
     check_valid_inputs(query, key, value)
 
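
Since the updated error message now names two supported releases, callers that behave differently under flash-attn v1 and v2 need a way to tell which major version is installed. A minimal sketch of such a check follows; it is not part of this patch, the helper name is hypothetical, and it assumes flash-attn exposes the standard __version__ attribute:

    from packaging import version


    def is_flash_v2_installed() -> bool:
        """Return True if an importable flash-attn >= 2.0.0 is present."""
        try:
            import flash_attn  # type: ignore
        except ImportError:
            # flash-attn is not installed at all.
            return False
        return version.parse(flash_attn.__version__) >= version.parse('2.0.0')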