Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 21, 2024
1 parent a85b8ea commit daf982a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,13 @@ def _get_write_mode(name: str) -> str:


def _is_rng_key(key: str, value: tuple) -> bool:
"""Check if the key is an RNG key."""
"""Check if the key is an RNG key.
We expect the RNG key to be of the form 'rng.{rank}.cuda|torch|python|numpy'.
This function ensures that we don't accidentally pick up other keys.
"""
starts_with_rng = key.startswith('rng')
ends_with_expected = key.endswith('cuda') or key.endswith('torch') or key.endswith(
'python',
) or key.endswith('numpy')
ends_with_expected = key.endswith(('cuda', 'torch', 'python', 'numpy'))
three_parts = isinstance(value, tuple) and len(value) == 3
if starts_with_rng and ends_with_expected and three_parts:
return True
Expand Down

0 comments on commit daf982a

Please sign in to comment.