From e47130a4c7edd859c306c7c0596c32f040ff675a Mon Sep 17 00:00:00 2001
From: Saaketh
Date: Mon, 17 Jun 2024 19:51:28 -0700
Subject: [PATCH] Revert "Optionally use `flash-attn`'s CE loss for metrics
 (#3394)"

This reverts commit 2cf9262e988c7cc4ee107259b98efec0298c5017.

Reverts the optional use of `flash-attn`'s cross-entropy loss when
computing the LanguageCrossEntropy metric.
---
 .github/workflows/pr-cpu.yaml       |  2 +-
 composer/devices/device_gpu.py      |  3 -
 composer/metrics/nlp.py             | 22 +------
 tests/checkpoint/test_state_dict.py |  6 +-
 tests/metrics/test_nlp_metrics.py   | 89 -----------------------------
 5 files changed, 4 insertions(+), 118 deletions(-)

diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml
index 12f471749e..1bdb383823 100644
--- a/.github/workflows/pr-cpu.yaml
+++ b/.github/workflows/pr-cpu.yaml
@@ -22,7 +22,7 @@ jobs:
             markers: not daily and not remote and not gpu and not doctest
             pytest_command: coverage run -m pytest
           - name: cpu-3.11-2.3
-            container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04
+            container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
             markers: not daily and not remote and not gpu and not doctest
             pytest_command: coverage run -m pytest
           - name: cpu-doctest
diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py
index 401368576e..19cb0a774a 100644
--- a/composer/devices/device_gpu.py
+++ b/composer/devices/device_gpu.py
@@ -12,7 +12,6 @@
 import torch.backends.cudnn
 import torch.cuda
 import torch.cuda.amp
-import torch.distributed as torch_dist
 import torch.utils.data
 
 from composer.devices.device import Device
@@ -43,8 +42,6 @@ def __init__(
     ):
         if not torch.cuda.is_available():
             raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')
-        if torch_dist.is_gloo_available():
-            DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo'
         if device_id is None:
             device_id = dist.get_local_rank()
         self._device = torch.device(f'cuda:{device_id}')
diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py
index c1562e5936..e6877292cf 100644
--- a/composer/metrics/nlp.py
+++ b/composer/metrics/nlp.py
@@ -83,21 +83,7 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
 
         self.ignore_index = ignore_index
-        self.flash_loss_fn = None
-        try:
-            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
-            log.debug(
-                'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' +
-                'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.',
-            )
-            self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
-        except ImportError:
-            if torch.cuda.is_available():
-                log.debug(
-                    'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
-                    'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.',
-                )
-        self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
         self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
         self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')
 
@@ -118,11 +104,7 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
 
         target = target.view(-1)
         logits = logits.view(target.shape[0], -1)
-        # Use Flash attn's CE loss function, if available, if inputs are both CUDA tensors.
-        if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda:
-            losses = self.flash_loss_fn(logits, target)
-        else:
-            losses = self.torch_loss_fn(logits, target)
+        losses = self.loss_fn(logits, target)
 
         total_items = (target != self.ignore_index).sum()
         self.total_items += total_items  #type: ignore (third-party)
diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py
index e010440836..4f719254a7 100644
--- a/tests/checkpoint/test_state_dict.py
+++ b/tests/checkpoint/test_state_dict.py
@@ -7,7 +7,6 @@
 
 import pytest
 import torch
-import torch.distributed as torch_dist
 from packaging import version
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim.lr_scheduler import StepLR
@@ -440,10 +439,7 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz
 
     assert 'model_name' in metadata_sd
     assert 'dist_backend' in metadata_sd
-    if torch_dist.is_gloo_available():
-        assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo'
-    else:
-        assert metadata_sd['dist_backend'] == 'nccl'
+    assert metadata_sd['dist_backend'] == 'nccl'
 
 
 @pytest.mark.filterwarnings('ignore:SWA has')
diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py
index 9b198003d3..7fe854bd96 100644
--- a/tests/metrics/test_nlp_metrics.py
+++ b/tests/metrics/test_nlp_metrics.py
@@ -14,7 +14,6 @@
     LanguagePerplexity,
     MaskedAccuracy,
 )
-from tests.common import device
 
 
 @pytest.mark.parametrize('ignore_index', [-100])
@@ -51,100 +50,12 @@ def test_masked_accuracy(ignore_index, num_classes):
     assert abs(final_acc - (1.0 / num_classes)) < 0.02
 
 
-@device('cpu', 'gpu')
 @pytest.mark.parametrize('ignore_index', [-100])
 @pytest.mark.parametrize('batch_size', [1e2, 1e3])
 @pytest.mark.parametrize('sequence_length', [128])
 @pytest.mark.parametrize('num_classes', [2, 10])
 @pytest.mark.parametrize('minibatch_size', [56, 256, 768])
-@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu'])
 def test_cross_entropy(
-    device: str,
-    batch_size: float,
-    ignore_index: Optional[int],
-    sequence_length: int,
-    num_classes: int,
-    minibatch_size: int,
-    tensor_device: str,
-):
-    """Sanity check to make sure that batched CrossEntropyLoss matches the expected performance.
-
-    Generates a predicted distribution from a normal distribution, and a ground truth from a normal distribution.
-    Verifies Cross Entropy Loss against the baseline performance.
-
-    Args:
-        device (str): the device to run the test on
-        batch_size (int): how many samples are in each batch
-        ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations.
-        sequence_length (int): the length of the generated sequence
-        num_classes (int): the number of classes in the classification task
-        minibatch_size (int): the minibatch size to simulate for model predictions
-        tensor_device (str): which device the input tensors to the metric are on
-    """
-
-    if device == 'cpu' and tensor_device == 'gpu':
-        pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')
-
-    batch_size = int(batch_size)
-    generated_preds = torch.randn((batch_size, sequence_length, num_classes))
-    generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length))
-
-    assert ignore_index is not None
-    torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
-    ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
-
-    if tensor_device == 'cpu':
-        torchmetrics_xent = torchmetrics_xent.to('cpu')
-        ce_with_keys_metric = ce_with_keys_metric.to('cpu')
-    elif tensor_device == 'gpu':
-        torchmetrics_xent = torchmetrics_xent.to('cuda')
-        ce_with_keys_metric = ce_with_keys_metric.to('cuda')
-
-    if device == 'gpu':
-        assert torchmetrics_xent.flash_loss_fn is not None
-
-    labels_mask = torch.rand((batch_size, sequence_length))
-    labels_mask[labels_mask > 0.8] = 1
-    labels_mask[labels_mask <= 0.8] = 0
-    labels_mask = labels_mask.bool()
-    generated_true[labels_mask] = ignore_index
-
-    num_batches = math.ceil(batch_size / minibatch_size)
-    for batch_idx in range(num_batches):
-        begin_idx = (batch_idx * minibatch_size)
-        end_idx = ((batch_idx + 1) * minibatch_size)
-        preds_subset = generated_preds[begin_idx:end_idx]
-        true_subset = generated_true[begin_idx:end_idx]
-
-        if tensor_device == 'cpu':
-            preds_subset = preds_subset.cpu()
-            true_subset = true_subset.cpu()
-        elif tensor_device == 'gpu':
-            preds_subset = preds_subset.cuda()
-            true_subset = true_subset.cuda()
-
-        torchmetrics_xent.update(preds_subset, true_subset)
-        ce_with_keys_metric.update(
-            {
-                'logits': preds_subset.view(-1, num_classes),
-                'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)),
-            },
-            true_subset.view(-1),
-        )
-
-    torchmetrics_loss = torchmetrics_xent.compute()
-    ce_with_keys_loss = ce_with_keys_metric.compute()
-    correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1))
-    assert torchmetrics_loss == ce_with_keys_loss
-    assert torch.isclose(correct_loss, torchmetrics_loss)
-
-
-@pytest.mark.parametrize('ignore_index', [-100])
-@pytest.mark.parametrize('batch_size', [1e2, 1e3])
-@pytest.mark.parametrize('sequence_length', [128])
-@pytest.mark.parametrize('num_classes', [2, 10])
-@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
-def test_torch_cpu_cross_entropy(
     batch_size: float,
     ignore_index: Optional[int],
     sequence_length: int,