Skip to content

Commit

Permalink
Merge branch 'mosaicml:main' into tp-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok committed Sep 23, 2024
2 parents 895f08e + 17304a0 commit 9df88cb
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 48 deletions.
4 changes: 2 additions & 2 deletions composer/loggers/mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ class MosaicMLLogger(LoggerDestination):
Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics.
(default: ``None``)
ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
ignore_exceptions: Flag to disable logging exceptions. Defaults to True.
"""

def __init__(
self,
log_interval: int = 60,
ignore_keys: Optional[list[str]] = None,
ignore_exceptions: bool = False,
ignore_exceptions: bool = True,
) -> None:
self.log_interval = log_interval
self.ignore_keys = ignore_keys
Expand Down
6 changes: 1 addition & 5 deletions composer/trainer/_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,8 +945,7 @@ def unshard_with_sync(self):

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.1'):
# 2.4.0 only patch
) < version.parse('2.4.2'):
# PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from typing import Mapping, Collection
Expand Down Expand Up @@ -1003,9 +1002,6 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
for key, value in state_dict.items():
_traverse_obj((str(key),), value)

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.2'):
# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
original_unshard = FlatParamHandle.unshard
Expand Down
3 changes: 3 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,13 +1902,16 @@ def __init__(
log.info('No previous autoresume checkpoint found')
# Actually load the checkpoint from potentially updated arguments
if load_path is not None:
log.info(f'Loading checkpoint from {load_path}')
if load_object_store is None:
load_object_store = maybe_create_object_store_from_uri(load_path)
log.debug(f'Created object store from load path: {load_object_store}')
if isinstance(load_object_store, WandBLogger):
import wandb
if wandb.run is None:
load_object_store.init(self.state, self.logger)
_, _, parsed_load_path = parse_uri(load_path)
log.debug(f'Parsed load path: {parsed_load_path}')

self._rng_state = checkpoint.load_checkpoint(
state=self.state,
Expand Down
89 changes: 49 additions & 40 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,25 @@ def _get_write_mode(name: str) -> str:
raise ValueError(f'{name} does not end with a valid tarfile extension.')


def _is_rng_key(key: str, value: tuple) -> bool:
"""Check if the key is an RNG key.
We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'.
This function ensures that we don't accidentally pick up other keys.
"""
starts_with_rng = key.startswith('rng')
ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
three_parts = isinstance(value, tuple) and len(value) == 3
if starts_with_rng and ends_with_expected and three_parts:
return True

return False


def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
if 'rng' in field_name:
if _is_rng_key(field_name, field_value):
_, rng_rank_index, _ = field_value
rng_inds.append(rng_rank_index)
rng_inds = set(rng_inds)
Expand Down Expand Up @@ -402,6 +417,7 @@ def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination,
if source_path.endswith('.symlink') or os.path.islink(source_path):
source_path = extract_path_from_symlink(source_path, object_store=object_store)
metadata_path = str(Path(source_path) / Path('.metadata'))
log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
if object_store is None:
return not os.path.exists(metadata_path)
else:
Expand Down Expand Up @@ -516,13 +532,15 @@ def load_checkpoint(
:attr:`load_weights_only` is not None. Otherwise, None.
"""
path = partial_format(path, run_name=state.run_name)
log.debug(f'Loading checkpoint from formatted path: {path}')
using_legacy_sharded = False
if state.fsdp_sharded_state_dict_enabled:
assert object_store is None or isinstance(
object_store,
ObjectStore,
), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')

if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
rng_state_dicts = load_sharded_checkpoint(
Expand Down Expand Up @@ -608,50 +626,41 @@ def dist_cp_load(
load_planner: Optional[LoadPlanner] = None,
):
if version.parse(torch.__version__) >= version.parse('2.4.0'):
if version.parse(torch.__version__) < version.parse('2.4.1'):
# PyTorch 2.4.0
from torch.distributed.checkpoint.utils import CheckpointException
try:
dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
except CheckpointException as e:
checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
# Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
# We override the traverse_state_dict so that the load planner could
# use the old way of flattening the state dict
log.debug('Trying to load checkpointing saved before torch 2.4')

import torch.distributed.checkpoint._nested_dict as nested_dict
import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

nested_dict.traverse_state_dict = backward_compatible_traverse
sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
# Revert the override
nested_dict.traverse_state_dict = traverse_2_4_0
sharded_tensor_util.traverse_state_dict = traverse_2_4_0
else:
raise e
else:
# PyTorch 2.4.1
from torch.distributed.checkpoint.utils import CheckpointException
try:
dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
except CheckpointException as e:
checkpoint_metadata = storage_reader.read_metadata().state_dict_metadata
if 'state.metadata' in checkpoint_metadata and 'state.metadata.composer_env_info.composer_version' not in checkpoint_metadata:
# Torch 2.4 changed the way how state dict is flattened. It broke backward compatibility.
# Torch issue: https://github.com/pytorch/pytorch/issues/133923.
# We override the traverse_state_dict so that the load planner could
# use the old way of flattening the state dict
log.debug('Trying to load checkpointing saved before torch 2.4')

import torch.distributed.checkpoint._nested_dict as nested_dict
import torch.distributed.checkpoint._sharded_tensor_utils as sharded_tensor_util
from torch.distributed.checkpoint._traverse import traverse_state_dict as traverse_2_4_0

from composer.trainer._patch_pytorch import traverse_state_dict as backward_compatible_traverse

nested_dict.traverse_state_dict = backward_compatible_traverse
sharded_tensor_util.traverse_state_dict = backward_compatible_traverse

dist_cp.load(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
)
# Revert the override
nested_dict.traverse_state_dict = traverse_2_4_0
sharded_tensor_util.traverse_state_dict = traverse_2_4_0
else:
raise e
else:
dist_cp.load_state_dict(
state_dict=state_dict,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def package_files(prefix: str, directory: str, extension: str):

extra_deps['nlp'] = [
'transformers>=4.11,!=4.34.0,<4.45',
'datasets>=2.4,<3',
'datasets>=2.4,<4',
'huggingface-hub>=0.21.2,<0.25',
]

Expand Down
5 changes: 5 additions & 0 deletions tests/callbacks/callback_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@
NeptuneLogger: {
'mode': 'debug',
},
WandBLogger: {
'init_kwargs': {
'mode': 'offline',
},
},
composer.profiler.Profiler: {
'trace_handlers': [MagicMock()],
'schedule': composer.profiler.cyclic_schedule(),
Expand Down
18 changes: 18 additions & 0 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_COMPOSER_STATES_FILENAME,
PartialFilePath,
_ensure_valid_checkpoint,
_is_rng_key,
_write_checkpoint_file,
glob_filter,
)
Expand Down Expand Up @@ -130,6 +131,23 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
assert all(keys_in) or not any(keys_in)


@pytest.mark.parametrize(
'key,value,expected_result',
[
('rng.0.cuda', ('rng', '0', 'cuda'), True),
('rng.0.torch', ('rng', '0', 'torch'), True),
('rng.0.numpy', ('rng', '0', 'numpy'), True),
('rng.0.python', ('rng', '0', 'python'), True),
('rng.0', ('rng', '0'), False),
('test.test.rng', ('test', 'test', 'rng'), False),
('test.rng.test', ('test', 'rng', 'test'), False),
('test.notatuple.test', 0, False),
],
)
def test_is_rng_key(key: str, value: tuple, expected_result: bool):
assert _is_rng_key(key, value) == expected_result


@pytest.mark.parametrize(
'remove_field_paths,filter_params',
[
Expand Down

0 comments on commit 9df88cb

Please sign in to comment.