Skip to content

Commit

Permalink
Remove PR curve metrics from backward compatibility test and skip tor…
Browse files Browse the repository at this point in the history
…ch 1.13 (mosaicml#2497)
  • Loading branch information
eracah committed Aug 31, 2023
1 parent 111a9f0 commit fe2e264
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from packaging import version
from torch.utils.data import DataLoader
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassAveragePrecision, MulticlassROC
from torchmetrics.classification import MulticlassAccuracy

from composer.algorithms import EMA
from composer.core.state import fsdp_get_optim_state_dict, fsdp_state_dict_type_context
Expand Down Expand Up @@ -332,6 +332,13 @@ def test_fsdp_mixed_with_sync(world_size, tmp_path: pathlib.Path, sync_module_st
def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision: str, sharding_strategy: str,
state_dict_type: str, s3_bucket: str, s3_read_only_prefix: str,
composer_version: str):

if version.parse(torch.__version__) >= version.parse('1.13.0') and composer_version not in [
'0.13.5', '0.14.0', '0.14.1'
]:
pytest.skip(
'Composer 0.15.1 and above checkpoints were saved with torch 2 and as a result are not compatible with torch 1.13.'
)
if version.parse(torch.__version__) >= version.parse('2.0.0') and state_dict_type == 'local':
pytest.xfail(
'Loading a torch 1.13 checkpoint with torch 2.0 for state_dict_type local is not backwards compatible. See https://github.com/pytorch/pytorch/issues/102667 for more info'
Expand All @@ -352,13 +359,9 @@ def test_fsdp_load_old_checkpoint(world_size, tmp_path: pathlib.Path, precision:
num_classes = 8 # This parameter setting is very important. Don't change or the test will fail.
train_metrics = MetricCollection([
MulticlassAccuracy(num_classes=num_classes),
MulticlassAveragePrecision(num_classes=num_classes),
MulticlassROC(num_classes=num_classes)
])
val_metrics = MetricCollection([
MulticlassAccuracy(num_classes=num_classes),
MulticlassAveragePrecision(num_classes=num_classes),
MulticlassROC(num_classes=num_classes)
])
else:
train_metrics = None
Expand Down

0 comments on commit fe2e264

Please sign in to comment.