Skip to content

Commit

Permalink
Merge branch 'main' of github.com-mvpatel2000:mosaicml/composer
Browse files: browse the repository at this point in the history
  • Loading branch information
mvpatel2000 committed Sep 23, 2024
2 parents 29e6be1 + 7597ab6 commit 265ba14
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
17 changes: 16 additions & 1 deletion composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,25 @@ def _get_write_mode(name: str) -> str:
raise ValueError(f'{name} does not end with a valid tarfile extension.')


def _is_rng_key(key: str, value: tuple) -> bool:
"""Check if the key is an RNG key.
We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'.
This function ensures that we don't accidentally pick up other keys.
"""
starts_with_rng = key.startswith('rng')
ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
three_parts = isinstance(value, tuple) and len(value) == 3
if starts_with_rng and ends_with_expected and three_parts:
return True

return False


def _get_num_ranks_that_saved_rng(metadata: Metadata):
rng_inds = []
for field_name, field_value in metadata.planner_data.items():
if 'rng' in field_name:
if _is_rng_key(field_name, field_value):
_, rng_rank_index, _ = field_value
rng_inds.append(rng_rank_index)
rng_inds = set(rng_inds)
Expand Down
4 changes: 2 additions & 2 deletions tests/loggers/test_wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,10 +269,10 @@ def test_wandb_log_metrics(test_wandb_logger):
eval_metrics_cross_entropy_count = all_run_text.count('metrics/eval/CrossEntropy')
train_loss_count = all_run_text.count('loss/train/total')

expected_number_train_loss_count = (dataset_size / batch_size) + 1 # wandb includes it in the file one extra time
expected_number_train_loss_count = (dataset_size / batch_size) * 2 # wandb includes it twice per step
expected_number_train_metrics_count = (
dataset_size / batch_size
) + 2 # wandb includes it in the file two extra times
) * 2 + 2 # wandb includes it twice per step plus two extra times
expected_number_eval_metrics_count = 2 # wandb includes it in the file twice
assert train_metrics_accuracy_count == expected_number_train_metrics_count
assert train_loss_count == expected_number_train_loss_count
Expand Down
18 changes: 18 additions & 0 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_COMPOSER_STATES_FILENAME,
PartialFilePath,
_ensure_valid_checkpoint,
_is_rng_key,
_write_checkpoint_file,
glob_filter,
)
Expand Down Expand Up @@ -130,6 +131,23 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
assert all(keys_in) or not any(keys_in)


@pytest.mark.parametrize(
    'key,value,expected_result',
    [
        # Well-formed RNG entries: 'rng.{rank}.{source}' with a three-element tuple.
        ('rng.0.cuda', ('rng', '0', 'cuda'), True),
        ('rng.0.torch', ('rng', '0', 'torch'), True),
        ('rng.0.numpy', ('rng', '0', 'numpy'), True),
        ('rng.0.python', ('rng', '0', 'python'), True),
        # Too few parts.
        ('rng.0', ('rng', '0'), False),
        # 'rng' appears, but not as the leading component.
        ('test.test.rng', ('test', 'test', 'rng'), False),
        ('test.rng.test', ('test', 'rng', 'test'), False),
        # The value is not a tuple at all.
        ('test.notatuple.test', 0, False),
    ],
)
def test_is_rng_key(key: str, value: tuple, expected_result: bool):
    """Verify that ``_is_rng_key`` returns the expected result for each key/value pair."""
    actual_result = _is_rng_key(key, value)
    assert actual_result == expected_result


@pytest.mark.parametrize(
'remove_field_paths,filter_params',
[
Expand Down

0 comments on commit 265ba14

Please sign in to comment.