From fe2e264a43e01a24d9a8651041a39f9737a5001f Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 31 Aug 2023 16:08:27 -0700 Subject: [PATCH] Remove PR curve metrics from backward compatibility test and skip torch 1.13 (#2497) --- tests/trainer/test_fsdp_checkpoint.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index e42d8da969..2d9fb50158 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -14,7 +14,7 @@ from packaging import version from torch.utils.data import DataLoader from torchmetrics import MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision, MulticlassROC +from torchmetrics.classification import MulticlassAccuracy from composer.algorithms import EMA from composer.core.state import fsdp_get_optim_state_dict, fsdp_state_dict_type_context @@ -332,6 +332,13 @@ def test_fsdp_mixed_with_sync(world_size, tmp_path: pathlib.Path, sync_module_st def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision: str, sharding_strategy: str, state_dict_type: str, s3_bucket: str, s3_read_only_prefix: str, composer_version: str): + + if version.parse(torch.__version__) >= version.parse('1.13.0') and composer_version not in [ + '0.13.5', '0.14.0', '0.14.1' + ]: + pytest.skip( + 'Composer 0.15.1 and above checkpoints were saved with torch 2 and as a result are not compatible with torch 1.13.' + ) if version.parse(torch.__version__) >= version.parse('2.0.0') and state_dict_type == 'local': pytest.xfail( 'Loading a torch 1.13 checkpoint with torch 2.0 for state_dict_type local is not backwards compatible. See https://github.com/pytorch/pytorch/issues/102667 for more info' @@ -352,13 +359,9 @@ def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision: num_classes = 8 # This parameter setting is very important. Don't change or the test will fail. train_metrics = MetricCollection([ MulticlassAccuracy(num_classes=num_classes), - MulticlassAveragePrecision(num_classes=num_classes), - MulticlassROC(num_classes=num_classes) ]) val_metrics = MetricCollection([ MulticlassAccuracy(num_classes=num_classes), - MulticlassAveragePrecision(num_classes=num_classes), - MulticlassROC(num_classes=num_classes) ]) else: train_metrics = None