From b983a050a39151b7ad3bf7e8171051d7d7c89168 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 4 Jun 2024 18:37:23 -0400 Subject: [PATCH 1/3] remove (#3362) --- tests/trainer/test_fsdp_checkpoint.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 30ec369dc6..60a757f3de 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -36,7 +36,6 @@ from composer.utils.object_store import S3ObjectStore from composer.utils.reproducibility import get_rng_state from tests.common import RandomClassificationDataset, deep_compare -from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent @@ -188,7 +187,6 @@ def _compare_optims_between_state_dicts(state_dict1, state_dict2): state_dict1_moment = state_dict1_moment.to_local() if isinstance(state_dict2_moment, DTensor): state_dict2_moment = state_dict2_moment.to_local() - print(param_name, state_dict1_moment, state_dict2_moment) torch.testing.assert_close(state_dict1_moment, state_dict2_moment) From 6a4303a195207759e891ce133bc80178b0d89367 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Tue, 4 Jun 2024 19:21:22 -0700 Subject: [PATCH 2/3] [ckpt-rewr] Resumption state dict API (#3324) --- composer/checkpoint/__init__.py | 8 +- composer/checkpoint/state_dict.py | 135 +++++++++++++++++++++++++++- tests/checkpoint/test_state_dict.py | 126 ++++++++++++++++++++++++-- 3 files changed, 257 insertions(+), 12 deletions(-) diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index 84d6c9f4cf..d4b21c790d 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,10 +3,16 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict +from composer.checkpoint.state_dict import ( + get_metadata_state_dict, + get_model_state_dict, + get_optim_state_dict, + get_resumption_state_dict, +) __all__ = [ 'get_model_state_dict', 'get_optim_state_dict', 'get_metadata_state_dict', + 'get_resumption_state_dict', ] diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index f66193ac8d..a20baaf165 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -6,21 +6,32 @@ import fnmatch import logging import sys -from typing import Any, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union + +from torch.utils.data import DataLoader, Dataset + +from composer.core.data_spec import DataSpec + +if TYPE_CHECKING: + from composer.core.evaluator import Evaluator import torch from packaging import version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from composer.core.evaluator import Evaluator +from composer.core.state import State +from composer.core.time import Timestamp from composer.devices import Device from composer.models import ComposerModel, HuggingFaceModel -from composer.utils import STR_TO_DTYPE, dist, get_composer_env_dict +from composer.utils import STR_TO_DTYPE, dist, get_composer_env_dict, reproducibility log = logging.getLogger(__name__) -__all__ = ['get_model_state_dict', 'get_optim_state_dict'] +__all__ = 
['get_model_state_dict', 'get_optim_state_dict', 'get_metadata_state_dict', 'get_resumption_state_dict'] def get_model_state_dict( @@ -158,6 +169,93 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st return model_state_dict +def get_resumption_state_dict(state: State) -> Dict[str, Any]: + """Generate the state dict for any objects needed for resumption. + + This includes: + * timestamp + * scheduler + * dataset_state + * scaler + * rank_zero_seed + * callbacks + * algorithms + + Returns: + The state dict containing the objects needed for resumption. + """ + resumption_state_dict = {} + resumption_state_dict['dataset_state'] = get_dataset_state_dict( + state.train_dataloader, + state.timestamp, + ) + resumption_state_dict['timestamp'] = state.timestamp.state_dict() + + scheduler_state_dict = _make_state_dict_for_list_of_objects(state.schedulers) + if scheduler_state_dict != {}: + resumption_state_dict['schedulers'] = scheduler_state_dict + + # Use list of tuples to account for duplicates + callbacks_state_dict = _make_state_dict_for_list_of_objects(state.callbacks, use_list_of_tuples=True) + if callbacks_state_dict != {}: + resumption_state_dict['callbacks'] = callbacks_state_dict + + # Use list of tuples to preserve order. + algorithms_state_dict = _make_state_dict_for_list_of_objects(state.algorithms, use_list_of_tuples=True) + if algorithms_state_dict != {}: + resumption_state_dict['algorithms'] = algorithms_state_dict + + if state.scaler is not None: + scaler_sd = _make_state_dict_for_list_of_objects(state.scaler) + if scaler_sd != {}: + resumption_state_dict['scaler'] = state.scaler.state_dict() + + resumption_state_dict['rank_zero_seed'] = state.rank_zero_seed + resumption_state_dict['run_name'] = state.run_name + resumption_state_dict['rng'] = reproducibility.get_rng_state() + + return resumption_state_dict + + +def _make_state_dict_for_list_of_objects(objects: Union[Sequence[Any], Any], + use_list_of_tuples=False) -> Union[Dict[str, Any], List]: + object_list = [] + object_dict = {} + if not isinstance(objects, Sequence): + objects = [objects] + for obj in objects: + if not hasattr(obj, 'state_dict') or obj.state_dict() == {}: + continue + if use_list_of_tuples: + object_list.append((type(obj).__qualname__, obj.state_dict())) + else: + object_dict[type(obj).__qualname__] = obj.state_dict() + if use_list_of_tuples: + return object_list + else: + return object_dict + + +def get_dataset_state_dict( + train_dataloader: Optional[Union[DataLoader, Iterable]], + timestamp: Timestamp, +) -> Dict[str, Any]: + """Collect the state dict(s) of our train and eval dataset(s). + + Returns: + Dict[str, Any]: The state dict(s). + """ + dataset_state_dict = { + 'train': None, + } + dataset = _dataset_of(train_dataloader) + if hasattr(dataset, 'state_dict'): + num_samples = int(timestamp.sample_in_epoch.value) + dataset_state_dict['train'] = dataset.state_dict(num_samples, True) # pyright: ignore + + return dataset_state_dict + + def _get_optim_state_dict_with_fsdp_context_manager( model: nn.Module, optimizer: torch.optim.Optimizer, @@ -356,3 +454,34 @@ def get_metadata_state_dict( metadata_state_dict['precision'] = 'fp32' return metadata_state_dict + + +def _dataset_of(dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]: + """Get the dataset contained by the given dataloader-like object. 
+ + Args: + dataloader (Evaluator | DataSpec | DataLoader | Iterable, optional): The dataloader, wrapped dataloader, or + generic python iterable to get the dataset of, if applicable. + + Returns: + Dataset: Its dataset, if there is one. + """ + from composer.core.evaluator import Evaluator + + # If it's None, no dataset for you. + if dataloader is None: + return None + + # An Evaluator is a dataloader wrapped with metrics. Unwrap its dataloader. + if isinstance(dataloader, Evaluator): + dataloader = dataloader.dataloader + + # A DataSpec is a dataloader wrapped with an on-device transform. Unwrap its dataloader. + if isinstance(dataloader, DataSpec): + dataloader = dataloader.dataloader + + # If what we now have is an actual DataLoader, return its dataset. If not, return None. + if isinstance(dataloader, DataLoader): + return dataloader.dataset + else: + return None diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 63f9996e02..ee53a36ff9 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -1,17 +1,29 @@ # Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import datetime from typing import Any, Dict +from unittest.mock import MagicMock import pytest import torch from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim import adam - -from composer.checkpoint import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict -from composer.devices import DeviceGPU -from composer.utils import dist +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DataLoader + +from composer.algorithms import SWA +from composer.callbacks import SpeedMonitor +from composer.checkpoint import ( + get_metadata_state_dict, + get_model_state_dict, + get_optim_state_dict, + get_resumption_state_dict, +) +from composer.core import State +from composer.devices import DeviceCPU, DeviceGPU +from composer.utils import dist, reproducibility from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP, configure_tiny_gpt2_hf_model @@ -242,6 +254,7 @@ def _init_model_and_optimizer( take_step=True, use_fsdp=False, tensor_type='sharded_tensor', + device='cuda', ): model, loss_fn = _init_model( use_composer_model, @@ -250,6 +263,7 @@ def _init_model_and_optimizer( num_features=num_features, use_fsdp=use_fsdp, tensor_type=tensor_type, + device=device, ) optimizer = _init_optimizer( @@ -260,6 +274,7 @@ def _init_model_and_optimizer( batch_size=batch_size, num_features=num_features, take_step=take_step, + device=device, ) return model, optimizer @@ -271,13 +286,14 @@ def _init_model( batch_size=5, num_features=8, use_fsdp=False, + device='cuda', tensor_type='sharded_tensor', ): if use_composer_model: - model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') + model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device) loss_fn = model._loss_fn else: - model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda') + model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device) loss_fn = torch.nn.CrossEntropyLoss() if use_fsdp: @@ -307,9 +323,10 @@ def _init_optimizer( batch_size=5, num_features=8, take_step=True, + device='cuda', ): - inputs = torch.randn(batch_size, num_features, device='cuda') - targets 
= torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long) + inputs = torch.randn(batch_size, num_features, device=device) + targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long) batch = (inputs, targets) if use_composer_model else inputs optimizer = adam.Adam(model.parameters()) outputs = model(batch) @@ -514,3 +531,96 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz assert 'dist_backend' in metadata_sd assert metadata_sd['dist_backend'] == 'nccl' + + +@pytest.mark.filterwarnings('ignore:SWA has') +def test_get_resumption_state_dict(): + + model, optimizer = _init_model_and_optimizer(use_composer_model=True, take_step=True, device='cpu') + + rank_zero_seed = 10 + run_name = 'test_run' + device = DeviceCPU() + test_dataset_sd = {'foo': 0} + dataloader = MagicMock(spec=DataLoader) + dataloader.dataset = MagicMock() + dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd) + swa = SWA() + state = State( + model=model, + rank_zero_seed=rank_zero_seed, + run_name=run_name, + device=device, + train_dataloader=dataloader, + algorithms=[swa], + callbacks=[SpeedMonitor(), SpeedMonitor()], + ) + state.schedulers = StepLR(optimizer=optimizer, step_size=2) + rsd = get_resumption_state_dict(state) + + assert rsd['rank_zero_seed'] == rank_zero_seed + assert rsd['run_name'] == run_name + assert 'timestamp' in rsd + assert rsd['timestamp'] == { + 'iteration': 0, + 'epoch': 0, + 'batch': 0, + 'sample': 0, + 'token': 0, + 'epoch_in_iteration': 0, + 'batch_in_epoch': 0, + 'sample_in_epoch': 0, + 'token_in_epoch': 0, + 'total_wct': datetime.timedelta(0), + 'iteration_wct': datetime.timedelta(0), + 'epoch_wct': datetime.timedelta(0), + 'batch_wct': datetime.timedelta(0), + } + assert rsd['dataset_state'] == {'train': test_dataset_sd} + dict(rsd['algorithms'])['SWA'].pop('repr') + assert rsd['algorithms'] == [ + ( + 'SWA', + { + 'swa_model': None, + 'swa_completed': False, + 'swa_started': False, + 'swa_scheduler': None, + 'step_counter': 0, + }, + ), + ] + assert rsd['callbacks'] == [('SpeedMonitor', {'total_eval_wct': 0.0}), ('SpeedMonitor', {'total_eval_wct': 0.0})] + + +@pytest.mark.gpu +def test_get_resumption_state_dict_gpu(): + if version.parse(torch.__version__) >= version.parse('2.3.0'): + from torch.amp.grad_scaler import GradScaler + else: + from torch.cuda.amp.grad_scaler import GradScaler + + model, _ = _init_model_and_optimizer(use_composer_model=True, take_step=False, device='cuda') + + rank_zero_seed = 10 + run_name = 'test_run' + device = DeviceCPU() + test_dataset_sd = {'test': 0} + dataloader = MagicMock() + dataloader.dataset = MagicMock() + dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd) + state = State( + model=model, + rank_zero_seed=rank_zero_seed, + run_name=run_name, + device=device, + scaler=GradScaler(), + ) + rsd = get_resumption_state_dict(state) + assert 'scaler' in rsd + assert set( + rsd['scaler'].keys(), + ) == {'scale', 'growth_factor', 'backoff_factor', 'growth_interval', '_growth_tracker'} + + assert 'rng' in rsd + deep_compare(rsd['rng'], reproducibility.get_rng_state()) From fe6140ec02e9a539332280b97b30bd1b470756de Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Wed, 5 Jun 2024 12:07:57 -0400 Subject: [PATCH 3/3] Revert "Autoresume Validation with Max Duration (#3358)" (#3364) This reverts commit f0eae8a72535c264c18b56072c734693b65689bc. 
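For readers skimming this revert: once it lands, autoresume validation in `Trainer.__init__` checks only the checkpointing settings (the `max_duration` requirement is removed), and `fit()` no longer rejects `duration` or `reset_time=True` when autoresume is enabled. Below is a minimal, dependency-free sketch of the init-time checks that remain; the exact wording of the `save_overwrite` and `latest_filename` messages is an assumption, since only the `save_folder` message is visible in this diff.

def _validate_autoresume(save_folder, save_overwrite, latest_filename):
    # Sketch of the restored __init__ validation: max_duration is intentionally
    # NOT checked here anymore (that check is what this revert removes).
    error_message = ''
    if save_folder is None:
        error_message += 'The `save_folder` must be specified when autoresume is enabled. '
    if save_overwrite:
        # Assumed message; the diff only shows that `save_overwrite` is checked.
        error_message += '`save_overwrite` must be False when autoresume is enabled. '
    if latest_filename is None:
        # Assumed check/message, inferred from the latest_filename=None case in
        # the test_autoresume_fail parametrization below.
        error_message += 'The `latest_filename` must be specified when autoresume is enabled. '
    if error_message:
        raise ValueError(error_message)

The removed `fit()`-time guards on `duration` and `reset_time` are what the deleted `test_autoresume_fail_fit` test below covered.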
--- composer/trainer/trainer.py | 16 +-------------- tests/trainer/test_checkpoint.py | 34 +++++++------------------------- 2 files changed, 8 insertions(+), 42 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index cb42094f37..eb5080eaee 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1723,12 +1723,9 @@ def __init__( # Load Checkpoint self._rng_state = None # If autoresume is enabled, first check for existing checkpoints to load - self.autoresume = autoresume - if self.autoresume: + if autoresume: log.info('Searching for a previous checkpoint to autoresume') error_message = '' - if max_duration is None: - error_message += 'The `max_duration` must be specified on trainer.__init__ when autoresume is enabled. ' if save_folder is None: error_message += 'The `save_folder` must be specified when autoresume is enabled. ' if save_overwrite: @@ -2191,21 +2188,10 @@ def fit( # Reset Time if reset_time: - if self.autoresume: - raise ValueError( - 'Cannot specify `reset_time=True` when autoresume is enabled. Please instead ' - 'specify `load_ignore_keys` when constructing the Trainer, which will only ' - 'run on the initial load and not any subsequent autoresumptions.', - ) self.state.timestamp = Timestamp() # Max Duration if duration is not None: - if self.autoresume: - raise ValueError( - '`duration` cannot be specified when autoresume is enabled. Please instead ' - 'specify `max_duration` when constructing the Trainer.', - ) duration = ensure_time(duration, TimeUnit.EPOCH) if duration.unit == TimeUnit.SECOND: raise ValueError('Wall clock time not an allowed time unit.') diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index dc887fa5e2..d23b55875f 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -667,7 +667,6 @@ def get_trainer( max_duration: str = '2ep', latest_filename: str = 'latest-rank{rank}.pt', file_extension: str = '.pt', - use_scheduler: bool = True, **kwargs, ): if model is None: @@ -705,7 +704,7 @@ def get_trainer( save_filename='ep{epoch}' + file_extension, max_duration=max_duration, optimizers=optimizer, - schedulers=ExponentialScheduler(gamma=0.9) if use_scheduler else None, + schedulers=ExponentialScheduler(gamma=0.9), callbacks=callbacks, **kwargs, ) @@ -1213,43 +1212,24 @@ def test_load_weights_object_store(self, tmp_path): ) @pytest.mark.parametrize( - 'run_name,save_folder,save_overwrite,latest_filename,max_duration', + 'run_name,save_folder,save_overwrite,latest_filename', [ - [None, 'first', False, 'latest-rank{rank}.pt', '2ep'], - ['big-chungus', None, False, 'latest-rank{rank}.pt', '2ep'], - ['big-chungus', 'first', True, 'latest-rank{rank}.pt', '2ep'], - ['big-chungus', 'first', False, None, '2ep'], - ['big-chungus', 'first', False, 'latest-rank{rank}.pt', None], + [None, 'first', False, 'latest-rank{rank}.pt'], + ['big-chungus', None, False, 'latest-rank{rank}.pt'], + ['big-chungus', 'first', True, 'latest-rank{rank}.pt'], + ['big-chungus', 'first', False, None], ], ) - def test_autoresume_fail_init(self, run_name, save_folder, save_overwrite, latest_filename, max_duration): + def test_autoresume_fail(self, run_name, save_folder, save_overwrite, latest_filename): with pytest.raises(ValueError): self.get_trainer( latest_filename=latest_filename, save_overwrite=save_overwrite, save_folder=save_folder, run_name=run_name, - max_duration=max_duration, autoresume=True, - use_scheduler=False, ) - @pytest.mark.parametrize( - 'duration,reset_time', 
- [ - ['1ep', False], - [None, True], - ], - ) - def test_autoresume_fail_fit(self, duration: Optional[str], reset_time: bool): - trainer = self.get_trainer( - run_name='bigtrainer', - save_folder='first', - autoresume=True, - ) - with pytest.raises(ValueError): - trainer.fit(duration=duration, reset_time=reset_time) - def test_different_run_names(self): trainer_1 = self.get_trainer(