Skip to content

Commit

Permalink
Fix RNG key checking (#3623)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 21, 2024
1 parent 3c7fefb commit dfcbc45
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
17 changes: 16 additions & 1 deletion composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,25 @@ def _get_write_mode(name: str) -> str:
raise ValueError(f'{name} does not end with a valid tarfile extension.')


def _is_rng_key(key: str, value: tuple) -> bool:
"""Check if the key is an RNG key.
We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'.
This function ensures that we don't accidentally pick up other keys.
"""
starts_with_rng = key.startswith('rng')
ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
three_parts = isinstance(value, tuple) and len(value) == 3
if starts_with_rng and ends_with_expected and three_parts:
return True

return False


def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
if 'rng' in field_name:
if _is_rng_key(field_name, field_value):
_, rng_rank_index, _ = field_value
rng_inds.append(rng_rank_index)
rng_inds = set(rng_inds)
Expand Down
18 changes: 18 additions & 0 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_COMPOSER_STATES_FILENAME,
PartialFilePath,
_ensure_valid_checkpoint,
_is_rng_key,
_write_checkpoint_file,
glob_filter,
)
Expand Down Expand Up @@ -130,6 +131,23 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
assert all(keys_in) or not any(keys_in)


@pytest.mark.parametrize(
    'key,value,expected_result',
    [
        # Well-formed RNG entries: 'rng' prefix, known suffix, 3-tuple value.
        ('rng.0.cuda', ('rng', '0', 'cuda'), True),
        ('rng.0.torch', ('rng', '0', 'torch'), True),
        ('rng.0.numpy', ('rng', '0', 'numpy'), True),
        ('rng.0.python', ('rng', '0', 'python'), True),
        # Key does not have the expected prefix and/or suffix.
        ('rng.0', ('rng', '0'), False),
        ('rng.0.other', ('rng', '0', 'other'), False),
        ('test.test.rng', ('test', 'test', 'rng'), False),
        ('test.rng.test', ('test', 'rng', 'test'), False),
        ('test.notatuple.test', 0, False),
        # Key is well-formed, so only the value-shape check can reject these.
        ('rng.0.cuda', ('rng', '0'), False),
        ('rng.0.cuda', ['rng', '0', 'cuda'], False),
    ],
)
def test_is_rng_key(key: str, value: tuple, expected_result: bool):
    """Check that ``_is_rng_key`` accepts only well-formed RNG entries.

    Each case supplies a key, the value stored under it, and the expected
    verdict. The negative cases cover a bad key as well as a valid key paired
    with a value that is not a 3-tuple, so every condition is exercised alone.
    """
    assert _is_rng_key(key, value) == expected_result


@pytest.mark.parametrize(
'remove_field_paths,filter_params',
[
Expand Down

0 comments on commit dfcbc45

Please sign in to comment.