
Commit

Merge branch 'mvpatel2000/no-rng-dedup' of github.com:mvpatel2000/composer into mvpatel2000/no-rng-dedup
mvpatel2000 committed Feb 13, 2024
2 parents 487d222 + d0f9fd5 commit dec0056
Showing 2 changed files with 12 additions and 49 deletions.
52 changes: 7 additions & 45 deletions composer/utils/checkpoint.py
@@ -213,7 +213,9 @@ def __init__(self, source_path: str, destination_path: str, object_store: Union[
super().__init__(destination_path)

def read_data(self, plan: LoadPlan, planner: LoadPlanner):
first_replica = self.device_mesh is None or self.device_mesh.get_local_rank(mesh_dim=0) == 0
# Download files if not using HSDP or if on first replica with HSDP enabled
first_replica = self.device_mesh is None or self.device_mesh.ndim == 1 or (
self.device_mesh.ndim >= 2 and self.device_mesh.get_local_rank(mesh_dim=0) == 0)

# 1. Download to the destination all files this rank needs if on first replica
if first_replica:
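For context on the change above: with HSDP the device mesh is 2-D, where mesh dimension 0 is the replica dimension, so only ranks that are first within their replica group should download checkpoint files; with plain FSDP (1-D mesh) or no mesh at all, every rank counts as the first replica. A minimal sketch of the same predicate, using a hypothetical helper name and assuming a `device_mesh` object with the `ndim` and `get_local_rank` attributes used in the diff:

def _is_first_replica(device_mesh) -> bool:
    # No device mesh: not using HSDP, so this rank should download.
    if device_mesh is None:
        return True
    # 1-D mesh (plain FSDP): there is only one replica group.
    if device_mesh.ndim == 1:
        return True
    # 2-D+ mesh (HSDP): dim 0 is the replica dimension; only local rank 0
    # on that dimension downloads, the remaining replicas wait for it.
    return device_mesh.get_local_rank(mesh_dim=0) == 0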
@@ -545,8 +547,6 @@ def load_sharded_checkpoint(

if state.fsdp_config is None:
raise ValueError('Loading a sharded checkpoint requires passing an FSDP config to Trainer.')
load_planner = state.fsdp_config['load_planner']
_validate_load_planner(load_planner)

# Check to make sure source_path is a directory.
if object_store is None:
@@ -603,14 +603,14 @@ def load_sharded_checkpoint(
dist_cp.load( # type: ignore
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
planner=state.fsdp_config['load_planner'],
no_dist=(not dist.is_initialized()),
)
else:
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=load_planner,
planner=state.fsdp_config['load_planner'],
no_dist=(not dist.is_initialized()),
)
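
For context, the deduplicated `_validate_load_planner`/`_validate_save_planner` checks are removed and the planner is now read straight from the FSDP config at the call site (here for loading and, later in this diff, for saving). A hedged sketch of how a custom planner might be wired into that config, assuming it is a plain dict with the `load_planner`/`save_planner` keys read by this code; the Trainer wiring shown is illustrative, not taken from this commit:

from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner

fsdp_config = {
    'load_planner': DefaultLoadPlanner(),  # consumed as state.fsdp_config['load_planner']
    'save_planner': DefaultSavePlanner(),  # consumed as state.fsdp_config['save_planner']
}
# trainer = Trainer(model=..., fsdp_config=fsdp_config, ...)  # hypothetical usage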

@@ -822,40 +822,6 @@ def filter_func(state_dict: dict) -> None:
return filter_func


def _validate_save_planner(save_planner: Optional[Any]) -> None:
"""Checks that ``save_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.SavePlanner`.
TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
type hints.
Raises:
ValueError: If ``save_planner`` is not a
:class:`~torch.distributed.checkpoint.planner.SavePlanner`.
"""
from torch.distributed.checkpoint.planner import SavePlanner

if save_planner is not None and not isinstance(save_planner, SavePlanner):
raise ValueError((f'save_planner {type(save_planner)} is not a '
'torch.distributed.checkpoint.planner.SavePlanner'))


def _validate_load_planner(load_planner: Optional[Any]) -> None:
"""Checks that ``load_planner`` is an instance of a :class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
TODO(GRT-2456): Remove validation once we deprecate torch 1.13 and can use
type hints.
Raises:
ValueError: If ``load_planner`` is not a
:class:`~torch.distributed.checkpoint.planner.LoadPlanner`.
"""
from torch.distributed.checkpoint.planner import LoadPlanner

if load_planner is not None and not isinstance(load_planner, LoadPlanner):
raise ValueError((f'load_planner {type(load_planner)} is not a '
'torch.distributed.checkpoint.planner.LoadPlanner'))


def safe_torch_load(
composer_states_filepath: Union[Path, str],
map_location: str = 'cpu',
@@ -1065,10 +1031,6 @@ def _save_checkpoint(
elif state.fsdp_sharded_state_dict_enabled:
if state.fsdp_config is None:
raise ValueError('Saving a sharded checkpoint requires passing an FSDP config to Trainer.')
save_planner = state.fsdp_config['save_planner']
_validate_save_planner(save_planner)

import torch.distributed.checkpoint as dist_cp

log.debug(f'Saving sharded checkpoints to {save_filename}...')
process_group = None
@@ -1087,14 +1049,14 @@
dist_cp.save( # type: ignore
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(dirname),
planner=save_planner,
planner=state.fsdp_config['save_planner'],
process_group=process_group,
)
else:
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(dirname),
planner=save_planner,
planner=state.fsdp_config['save_planner'],
process_group=process_group,
)
log.debug('Finished pytorch save state dict')
9 changes: 5 additions & 4 deletions tests/utils/eval_client/test_local_eval_client.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

from composer.utils import LocalEvalClient
from composer.utils import LocalEvalClient, dist
from tests.common.markers import world_size


@@ -29,10 +29,11 @@
)
@world_size(1, 2)
def test_local_invoke(code: str, result: str, language: str, world_size: int, tmp_path: str):
"""Test invocation function for LocalEvalClient with code that succeeds, fails compilation, times out, and is incorrect in C, C++, Python, JS.
"""Test invocation function for LocalEvalClient.
Code can succeed, fail compilation, time out, or be incorrect in C, C++, Python, JS.
"""
import os
os.makedirs(os.path.dirname(tmp_path), exist_ok=True)
dist.barrier()  # Ensure all processes are ready to run the test, since invoke doesn't use dist
eval_client = LocalEvalClient()
input = '(1,)' if language == 'python' else '1'
assert eval_client.invoke([[[{
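On the test change: per the comment in the diff, `LocalEvalClient.invoke` makes no collective calls itself, so the added `dist.barrier()` makes every process reach the same point before the per-rank work starts. A minimal sketch of that pattern, with a hypothetical helper name, using the `dist` module imported above:

from composer.utils import dist

def run_per_rank_work(work_fn):
    # Synchronize all ranks first; the work itself performs no collectives.
    dist.barrier()
    return work_fn()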
