From 7597ab6e873e4c77caa5cb87ec06282123af1e60 Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 20 Sep 2024 20:22:59 -0700
Subject: [PATCH 1/6] Fix RNG key checking (#3623)

---
 composer/utils/checkpoint.py     | 17 ++++++++++++++++-
 tests/trainer/test_checkpoint.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index c6f5af15ca..b966c918c5 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -148,10 +148,25 @@ def _get_write_mode(name: str) -> str:
     raise ValueError(f'{name} does not end with a valid tarfile extension.')


+def _is_rng_key(key: str, value: tuple) -> bool:
+    """Check if the key is an RNG key.
+
+    We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'.
+    This function ensures that we don't accidentally pick up other keys.
+    """
+    starts_with_rng = key.startswith('rng')
+    ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
+    three_parts = isinstance(value, tuple) and len(value) == 3
+    if starts_with_rng and ends_with_expected and three_parts:
+        return True
+
+    return False
+
+
 def _get_num_ranks_that_saved_rng(metadata: Metadata):
     rng_inds = []
     for field_name, field_value in metadata.planner_data.items():
-        if 'rng' in field_name:
+        if _is_rng_key(field_name, field_value):
             _, rng_rank_index, _ = field_value
             rng_inds.append(rng_rank_index)
     rng_inds = set(rng_inds)
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 82629d245b..c2e4929535 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -35,6 +35,7 @@
     _COMPOSER_STATES_FILENAME,
     PartialFilePath,
     _ensure_valid_checkpoint,
+    _is_rng_key,
     _write_checkpoint_file,
     glob_filter,
 )
@@ -130,6 +131,23 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
     assert all(keys_in) or not any(keys_in)


+@pytest.mark.parametrize(
+    'key,value,expected_result',
+    [
+        ('rng.0.cuda', ('rng', '0', 'cuda'), True),
+        ('rng.0.torch', ('rng', '0', 'torch'), True),
+        ('rng.0.numpy', ('rng', '0', 'numpy'), True),
+        ('rng.0.python', ('rng', '0', 'python'), True),
+        ('rng.0', ('rng', '0'), False),
+        ('test.test.rng', ('test', 'test', 'rng'), False),
+        ('test.rng.test', ('test', 'rng', 'test'), False),
+        ('test.notatuple.test', 0, False),
+    ],
+)
+def test_is_rng_key(key: str, value: tuple, expected_result: bool):
+    assert _is_rng_key(key, value) == expected_result
+
+
 @pytest.mark.parametrize(
     'remove_field_paths,filter_params',
     [
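The practical effect of PATCH 1/6 is that `_get_num_ranks_that_saved_rng` no longer miscounts ranks when an unrelated checkpoint key merely contains the substring `'rng'`. Below is a minimal, self-contained sketch of the difference; the `planner_data` mapping is invented for illustration (in Composer it comes from the `Metadata.planner_data` of a sharded torch checkpoint), and `_is_rng_key` is adapted from the patch with the boolean returned directly.

```python
def _is_rng_key(key: str, value: tuple) -> bool:
    # Adapted from the patch: accept only 'rng.{rank}.{cuda|torch|python|numpy}' entries.
    starts_with_rng = key.startswith('rng')
    ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
    three_parts = isinstance(value, tuple) and len(value) == 3
    return starts_with_rng and ends_with_expected and three_parts


# Synthetic planner_data: two ranks saved RNG state, plus one unrelated key
# that happens to contain the substring 'rng'.
planner_data = {
    'rng.0.cuda': ('rng', '0', 'cuda'),
    'rng.0.torch': ('rng', '0', 'torch'),
    'rng.1.cuda': ('rng', '1', 'cuda'),
    'rng.1.torch': ('rng', '1', 'torch'),
    'state.model.rng_proj.weight': ('state', 'model', 'rng_proj.weight'),
}

# Old check: the substring match also picks up the model key, so the middle
# tuple element 'model' is mistakenly counted as a rank that saved RNG state.
old_ranks = {value[1] for key, value in planner_data.items() if 'rng' in key}
# New check: only genuine rng.{rank}.{generator} entries contribute.
new_ranks = {value[1] for key, value in planner_data.items() if _is_rng_key(key, value)}

print(sorted(old_ranks))  # ['0', '1', 'model'] -> 3 "ranks"
print(sorted(new_ranks))  # ['0', '1'] -> 2 ranks
```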
From 7e12e1f8bec48b860663e69bf075ec211d2c0557 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Mon, 23 Sep 2024 10:03:37 -0700
Subject: [PATCH 2/6] Update datasets requirement from <3,>=2.4 to >=2.4,<4 (#3626)

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 6cc65702a7..2f8d41de48 100644
--- a/setup.py
+++ b/setup.py
@@ -181,7 +181,7 @@ def package_files(prefix: str, directory: str, extension: str):

 extra_deps['nlp'] = [
     'transformers>=4.11,!=4.34.0,<4.45',
-    'datasets>=2.4,<3',
+    'datasets>=2.4,<4',
     'huggingface-hub>=0.21.2,<0.25',
 ]

From 80191b84bcce9f4ffe500259e969ed1678b93225 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 23 Sep 2024 10:03:56 -0700
Subject: [PATCH 3/6] Disable exceptions for MosaicML Logger (#3627)

---
 composer/loggers/mosaicml_logger.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/composer/loggers/mosaicml_logger.py b/composer/loggers/mosaicml_logger.py
index 1fcb612ed5..2094643fb1 100644
--- a/composer/loggers/mosaicml_logger.py
+++ b/composer/loggers/mosaicml_logger.py
@@ -62,14 +62,14 @@ class MosaicMLLogger(LoggerDestination):
             Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics.

             (default: ``None``)
-        ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
+        ignore_exceptions: Flag to disable logging exceptions. Defaults to True.
     """

     def __init__(
         self,
         log_interval: int = 60,
         ignore_keys: Optional[list[str]] = None,
-        ignore_exceptions: bool = False,
+        ignore_exceptions: bool = True,
     ) -> None:
         self.log_interval = log_interval
         self.ignore_keys = ignore_keys

From d2e1d5e64d6067c3009b380d276142441acc7db4 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 23 Sep 2024 11:15:31 -0700
Subject: [PATCH 4/6] Fix CPU dailies (#3628)

---
 tests/callbacks/callback_settings.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py
index 279f47a695..4b3ef478a2 100644
--- a/tests/callbacks/callback_settings.py
+++ b/tests/callbacks/callback_settings.py
@@ -146,6 +146,11 @@
     NeptuneLogger: {
         'mode': 'debug',
     },
+    WandBLogger: {
+        'init_kwargs': {
+            'mode': 'offline',
+        },
+    },
     composer.profiler.Profiler: {
         'trace_handlers': [MagicMock()],
         'schedule': composer.profiler.cyclic_schedule(),
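Two notes on the logger changes above. PATCH 3/6 flips the MosaicML logger so that exceptions raised while reporting to the platform are swallowed by default; callers who still want the old fail-fast behavior now have to opt back in. PATCH 4/6 points the WandB logger used in the callback tests at offline mode so the CPU daily jobs need no API key or network access. A short sketch of both, assuming only the constructor arguments visible in the diffs:

```python
from composer.loggers import MosaicMLLogger, WandBLogger

# PATCH 3 changes the default to ignore_exceptions=True; pass False explicitly
# to restore the previous behavior of raising when platform logging fails.
strict_mosaic_logger = MosaicMLLogger(log_interval=60, ignore_exceptions=False)

# Mirrors the test-settings entry added in PATCH 4: run W&B offline so no
# API key or network access is required (e.g. on CPU-only CI workers).
offline_wandb_logger = WandBLogger(init_kwargs={'mode': 'offline'})
```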
From 4cdc2cdd311c163ff801cb43d2a90f199f0d0d77 Mon Sep 17 00:00:00 2001
From: bigning
Date: Mon, 23 Sep 2024 13:04:27 -0700
Subject: [PATCH 5/6] fix 2.4.1ckpt (#3629)

---
 composer/trainer/_patch_pytorch.py |  6 +--
 composer/utils/checkpoint.py       | 69 +++++++++++++-----------
 2 files changed, 31 insertions(+), 44 deletions(-)

diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py
index fcca94d73a..77c4d733f7 100644
--- a/composer/trainer/_patch_pytorch.py
+++ b/composer/trainer/_patch_pytorch.py
@@ -945,8 +945,7 @@ def unshard_with_sync(self):

 if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
     torch.__version__,
-) < version.parse('2.4.1'):
-    # 2.4.0 only patch
+) < version.parse('2.4.2'):
     # PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
     from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
     from typing import Mapping, Collection
@@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
     for key, value in state_dict.items():
         _traverse_obj((str(key),), value)

-if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
-    torch.__version__,
-) < version.parse('2.4.2'):
     # Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
     from torch.distributed.fsdp._flat_param import FlatParamHandle
     original_unshard = FlatParamHandle.unshard
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index b966c918c5..11e79bff10 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -623,50 +623,41 @@ def dist_cp_load(
     load_planner: Optional[LoadPlanner] = None,
 ):
     if version.parse(torch.__version__) >= version.parse('2.4.0'):
-        if version.parse(torch.__version__) < version.parse('2.4.1'):
-            # PyTorch 2.4.0
-            from torch.distributed.checkpoint.utils import CheckpointException
-            try:
-                dist_cp.load(
-                    state_dict=state_dict,
-                    storage_reader=storage_reader,
-                    planner=load_planner,
-                )
-            except CheckpointException as e:
-                checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
-                if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
-                    # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
-                    # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
-                    # We override the traverse_state_dict so that the load planner could
-                    # use the old way of flattening the state dict
-                    log.debug('Trying to load checkpointing saved before torch 2.4')
-
-                    import torch.distributed.checkpoint._nested_dict as nested_dict
-                    import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
-                    from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
-
-                    from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
-
-                    nested_dict.traverse_state_dict = backward_compatible_traverse
-                    sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
-
-                    dist_cp.load(
-                        state_dict=state_dict,
-                        storage_reader=storage_reader,
-                        planner=load_planner,
-                    )
-                    # Revert the override
-                    nested_dict.traverse_state_dict = traverse_2_4_0
-                    sharded_tensor_util.traverse_state_dict = traverse_2_4_0
-                else:
-                    raise e
-        else:
-            # PyTorch 2.4.1
+        from torch.distributed.checkpoint.utils import CheckpointException
+        try:
             dist_cp.load(
                 state_dict=state_dict,
                 storage_reader=storage_reader,
                 planner=load_planner,
             )
+        except CheckpointException as e:
+            checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
+            if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
+                # Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
+                # Torch issue: https://github.com/pytorch/pytorch/issues/133923.
+                # We override the traverse_state_dict so that the load planner could
+                # use the old way of flattening the state dict
+                log.debug('Trying to load checkpointing saved before torch 2.4')
+
+                import torch.distributed.checkpoint._nested_dict as nested_dict
+                import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
+                from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0
+
+                from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse
+
+                nested_dict.traverse_state_dict = backward_compatible_traverse
+                sharded_tensor_util.traverse_state_dict = backward_compatible_traverse
+
+                dist_cp.load(
+                    state_dict=state_dict,
+                    storage_reader=storage_reader,
+                    planner=load_planner,
+                )
+                # Revert the override
+                nested_dict.traverse_state_dict = traverse_2_4_0
+                sharded_tensor_util.traverse_state_dict = traverse_2_4_0
+            else:
+                raise e
     else:
         dist_cp.load_state_dict(
             state_dict=state_dict,
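PATCH 5/6 collapses the separate 2.4.0 and 2.4.1 code paths into one: always attempt `dist_cp.load`, and only when it raises on a checkpoint written before torch 2.4 (detected via the flattened `'state.metadata'` key with no composer_version entry) temporarily swap in the backward-compatible `traverse_state_dict`, retry, and put the originals back. Stripped of the PyTorch specifics, the retry-with-temporary-override shape looks roughly like the sketch below; the module and function names here are placeholders, not Composer or PyTorch APIs, and the sketch restores the originals in a `finally` block, which is slightly more defensive than the inline revert in the patch.

```python
from types import SimpleNamespace
from typing import Any, Callable

# Stand-ins for the two torch modules whose traverse_state_dict is overridden
# in the patch (checkpoint._nested_dict and checkpoint._sharded_tensor_utils).
nested_dict = SimpleNamespace(traverse_state_dict=lambda *args, **kwargs: None)
sharded_tensor_util = SimpleNamespace(traverse_state_dict=lambda *args, **kwargs: None)


def load_with_legacy_fallback(
    load: Callable[[], Any],
    looks_like_pre_24_checkpoint: Callable[[], bool],
    legacy_traverse: Callable[..., Any],
) -> Any:
    """Try load(); if it fails on an old-layout checkpoint, retry with the
    legacy traversal patched in, then restore the original functions."""
    try:
        return load()
    except Exception:
        if not looks_like_pre_24_checkpoint():
            raise
        originals = (nested_dict.traverse_state_dict, sharded_tensor_util.traverse_state_dict)
        nested_dict.traverse_state_dict = legacy_traverse
        sharded_tensor_util.traverse_state_dict = legacy_traverse
        try:
            return load()
        finally:
            # Always revert the override so later loads see the stock behavior.
            nested_dict.traverse_state_dict, sharded_tensor_util.traverse_state_dict = originals
```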
From 17304a0b54daa8f982107cdf138afb3b0e0e6da8 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 23 Sep 2024 16:46:29 -0400
Subject: [PATCH 6/6] More checkpoint debug logs (#3632)

---
 composer/trainer/trainer.py  | 3 +++
 composer/utils/checkpoint.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 815aa50001..32a9426523 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1902,13 +1902,16 @@ def __init__(
                 log.info('No previous autoresume checkpoint found')
         # Actually load the checkpoint from potentially updated arguments
         if load_path is not None:
+            log.info(f'Loading checkpoint from {load_path}')
             if load_object_store is None:
                 load_object_store = maybe_create_object_store_from_uri(load_path)
+                log.debug(f'Created object store from load path: {load_object_store}')
             if isinstance(load_object_store, WandBLogger):
                 import wandb
                 if wandb.run is None:
                     load_object_store.init(self.state, self.logger)
             _, _, parsed_load_path = parse_uri(load_path)
+            log.debug(f'Parsed load path: {parsed_load_path}')

             self._rng_state = checkpoint.load_checkpoint(
                 state=self.state,
diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index 11e79bff10..ccefee4e60 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -417,6 +417,7 @@ def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination,
     if source_path.endswith('.symlink') or os.path.islink(source_path):
         source_path = extract_path_from_symlink(source_path, object_store=object_store)
     metadata_path = str(Path(source_path) / Path('.metadata'))
+    log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
     if object_store is None:
         return not os.path.exists(metadata_path)
     else:
@@ -531,6 +532,7 @@ def load_checkpoint(
         :attr:`load_weights_only` is not None. Otherwise, None.
     """
     path = partial_format(path, run_name=state.run_name)
+    log.debug(f'Loading checkpoint from formatted path: {path}')
     using_legacy_sharded = False
     if state.fsdp_sharded_state_dict_enabled:
         assert object_store is None or isinstance(
             object_store,
             ObjectStore,
         ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
         using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
+        log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')

     if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
         rng_state_dicts = load_sharded_checkpoint(