Updating the Flash Attention version to fix cross entropy loss #812

Merged
2 changes: 1 addition & 1 deletion llmfoundry/models/hf/hf_causal_lm.py
@@ -100,7 +100,7 @@ def __init__(self, om_model_config: Union[DictConfig,
if use_flash_attention_2 and not is_flash_v2_installed():
raise ValueError(
'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
- 'Please install flash_attn==2.3.2`.')
+ 'Please install flash_attn==2.3.6`.')

requested_attention_implementation = 'flash_attention_2' if use_flash_attention_2 else 'eager'
config = AutoConfig.from_pretrained(
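For context, the guard this hunk touches only needs to know whether a flash-attn 2.x build is importable before requesting `flash_attention_2`. A minimal sketch of such a check, assuming only the installed package version matters; llmfoundry's `is_flash_v2_installed()` helper may do more than this:

```python
# Minimal sketch of a flash-attn v2 presence check; assumes only the
# installed package version matters. llmfoundry's is_flash_v2_installed()
# may check more than this.
from packaging import version


def flash_v2_installed_sketch() -> bool:
    try:
        import flash_attn
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
```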
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -227,7 +227,7 @@ def flash_attn_fn(
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
except:
raise RuntimeError(
- 'Please install flash-attn==1.0.9 or flash-attn==2.3.2')
+ 'Please install flash-attn==1.0.9 or flash-attn==2.3.6')

check_valid_inputs(query, key, value)

@@ -344,7 +344,7 @@ def flash_attn_fn(
window_size=(sliding_window_size, sliding_window_size))
else:
raise RuntimeError(
- 'flash-attn==1.0.9 or flash-attn==2.3.2 is required.')
+ 'flash-attn==1.0.9 or flash-attn==2.3.6 is required.')

output = bert_padding.pad_input(
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size,
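The second hunk in this file sits inside the branch that passes `window_size` to the flash-attn kernel; the surrounding code (not fully shown in the hunk) only reaches it when a new-enough flash-attn 2.x is installed, and otherwise raises the error updated here. A hedged sketch of that gating, with illustrative names (`sliding_window_attn_kwargs` is not an llmfoundry function):

```python
# Hedged sketch: only pass the sliding-window argument to the flash-attn
# kernel when the installed version is new enough to accept it. The helper
# name and kwargs dict are illustrative, not llmfoundry's actual code.
from packaging import version


def sliding_window_attn_kwargs(sliding_window_size: int) -> dict:
    try:
        import flash_attn
    except ImportError:
        raise RuntimeError(
            'Please install flash-attn==1.0.9 or flash-attn==2.3.6')
    attn_kwargs: dict = {}
    if sliding_window_size != -1:
        # window_size was added to flash_attn_varlen_func in flash-attn 2.3
        # and takes (left, right) window extents.
        if version.parse(flash_attn.__version__) < version.parse('2.3.0'):
            raise RuntimeError(
                'Sliding window attention requires flash-attn>=2.3.0')
        attn_kwargs['window_size'] = (sliding_window_size,
                                      sliding_window_size)
    return attn_kwargs
```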
173 changes: 0 additions & 173 deletions llmfoundry/models/layers/cross_entropy_loss.py

This file was deleted.

2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/configuration_mpt.py
@@ -304,5 +304,5 @@ def _validate_config(self) -> None:
from flash_attn.bert_padding import unpad_input, pad_input # type: ignore # yapf: disable # isort: skip
except:
raise ImportError(
- 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.2'
+ 'In order to set `use_pad_tok_in_ffn=False`, please install flash-attn==1.0.9 or flash-attn==2.3.6'
)
6 changes: 1 addition & 5 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -972,11 +972,7 @@ def __init__(
loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
if loss_fn_config == 'fused_crossentropy':
try:
- # NOTE: The following is the original import statement from flash_attn library, which we have currently replaced with a copy pasted code from the same library's version 1.0.9. The reason is that using the CE loss from FA v2.3.2 results in an illegal memory access error at long sequence lengths (github issue: https://github.com/Dao-AILab/flash-attention/issues/714).
- # from flash_attn.losses.cross_entropy import \
- # CrossEntropyLoss as FusedCrossEntropyLoss
- # TODO: Once the problem with using FA v2's CE loss at longer sequence lengths is resolved (github issue: https://github.com/Dao-AILab/flash-attention/issues/714), revert back to directly importing the CE loss from FA library.
- from llmfoundry.models.layers.cross_entropy_loss import \
+ from flash_attn.losses.cross_entropy import \
CrossEntropyLoss as FusedCrossEntropyLoss

self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
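With the copied v1.0.9 loss file deleted, the selection in `MPTForCausalLM.__init__` reduces to importing the fused kernel straight from flash-attn. A rough sketch of that flow, assuming the same `fused_crossentropy` / `torch_crossentropy` config values used in modeling_mpt.py; the real constructor handles more configuration and error reporting:

```python
# Rough sketch of the loss selection after this change; the actual
# MPTForCausalLM.__init__ handles more configuration and error reporting.
import torch


def build_loss_fn(loss_fn_config: str = 'fused_crossentropy'):
    if loss_fn_config == 'fused_crossentropy':
        try:
            # With flash-attn==2.3.6 the fused kernel is imported directly
            # from the library; the copied v1.0.9 implementation is gone.
            from flash_attn.losses.cross_entropy import \
                CrossEntropyLoss as FusedCrossEntropyLoss
        except ImportError:
            raise ValueError(
                'Fused cross entropy requested, but flash-attn is not installed.')
        return FusedCrossEntropyLoss(ignore_index=-100)
    elif loss_fn_config == 'torch_crossentropy':
        return torch.nn.CrossEntropyLoss(ignore_index=-100)
    raise ValueError(f'Unknown loss_fn: {loss_fn_config}')
```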
4 changes: 1 addition & 3 deletions setup.py
@@ -98,10 +98,8 @@
'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]
extra_deps['gpu-flash2'] = [
- 'flash-attn==2.3.2',
+ 'flash-attn==2.3.6',
'mosaicml-turbo==0.0.4',
- # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI
- 'xentropy-cuda-lib@git+https://github.com/HazyResearch/[email protected]#subdirectory=csrc/xentropy',
]

extra_deps['peft'] = [
16 changes: 5 additions & 11 deletions tests/models/test_model.py
@@ -422,22 +422,16 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,


@pytest.mark.gpu
- @pytest.mark.parametrize('ce_loss_implementation',
- ['FA_v1_copied', 'FA_imported'])
- def test_loss_fn(ce_loss_implementation: str):
+ def test_loss_fn():
"""Tests the Fused CrossEntropy vs torch.nn.CrossEntropy loss function.

We provide non-zero tolerances to account for small numerics differences
between the two loss implementations.
"""
- if ce_loss_implementation == 'FA_imported':
- try:
- from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
- except:
- pytest.skip('Fused cross entropy was not installed')
- else:
- from llmfoundry.models.layers.cross_entropy_loss import \
- CrossEntropyLoss as FusedCrossEntropyLoss
+ try:
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip
+ except:
+ pytest.skip('Fused cross entropy was not installed')

# run numerical test in pure fp32
from torch.backends import cuda, cudnn
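The simplified test now only exercises the kernel imported from flash-attn. A hedged sketch of the comparison it performs (random inputs, fused vs. torch cross entropy, non-zero tolerances, CUDA required); the shapes and tolerances below are illustrative, not the test's exact values:

```python
# Hedged sketch of what the updated test checks: the fused cross entropy
# from flash-attn and torch.nn.CrossEntropyLoss agree within small
# tolerances on identical inputs. Requires a CUDA device; shapes and
# tolerances are illustrative.
import pytest
import torch


def loss_fn_comparison_sketch():
    try:
        from flash_attn.losses.cross_entropy import \
            CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore
    except ImportError:
        pytest.skip('Fused cross entropy was not installed')
    vocab_size, num_tokens = 512, 256
    logits = torch.randn(num_tokens, vocab_size, device='cuda')
    targets = torch.randint(0, vocab_size, (num_tokens,), device='cuda')
    fused = FusedCrossEntropyLoss(ignore_index=-100)(logits, targets)
    reference = torch.nn.CrossEntropyLoss(ignore_index=-100)(logits, targets)
    assert torch.allclose(fused, reference, rtol=1e-2, atol=1e-2)
```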