Use Gloo PG if available for both restore and restore_with_id methods (#897)

Summary:
Pull Request resolved: #897

Use Gloo PG if available for both restore and restore_with_id methods.
This diff moves the logic into restore_with_id, which is called by the restore method, ensuring the change takes effect for both code paths.
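
As a rough illustration, the consolidated logic follows the pattern sketched below. The helper name _load_with_gloo_fallback and the load_fn callback are hypothetical stand-ins used only for this sketch, not part of TorchTNT:

from datetime import timedelta
from typing import Callable, Optional

import torch.distributed as dist


def _load_with_gloo_fallback(
    checkpoint_id: str,
    load_fn: Callable[..., None],
    process_group: Optional[dist.ProcessGroup] = None,
) -> None:
    # Hypothetical helper mirroring the logic this diff moves into restore_with_id.
    gloo_pg_created = False
    pg = process_group
    if dist.is_initialized():
        pg = dist.group.WORLD if process_group is None else process_group
        if dist.get_backend(pg) != dist.Backend.GLOO:
            # Checkpoint loading happens over a Gloo group created just for this call.
            pg = dist.new_group(
                timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
            )
            gloo_pg_created = True
    try:
        load_fn(checkpoint_id, process_group=pg)
    finally:
        # The Gloo group exists only for the restore, so tear it down afterwards.
        if gloo_pg_created:
            dist.destroy_process_group(pg)

The try/finally wrapper is a convenience of this sketch; the change itself destroys the group after the load completes.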

Reviewed By: JKSenthil

Differential Revision: D62539308

fbshipit-source-id: bb37c2ce0e33027967c7ef5727ca09c3ec491fc6
saumishr authored and facebook-github-bot committed Sep 12, 2024
1 parent 665dd50 commit 33b98f4
Showing 2 changed files with 98 additions and 24 deletions.
73 changes: 73 additions & 0 deletions tests/framework/callbacks/test_dcp_saver_gpu.py
@@ -46,6 +46,30 @@ def _test_gloo_pg_restore(
tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
mock_destroy_process_group.assert_called_once()

@skip_if_not_distributed
@skip_if_not_gpu
    def test_gloo_pg_restore_with_id(self) -> None:
spawn_multi_process(
1,
"nccl",
            self._test_gloo_pg_restore_with_id,
)

@staticmethod
@patch("torch.distributed.destroy_process_group")
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def _test_gloo_pg_restore_with_id(
mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
) -> None:
tc = unittest.TestCase()
my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
DistributedCheckpointSaver.restore_with_id(
checkpoint_id="path/to/snapshot", unit=my_unit
)
process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
mock_destroy_process_group.assert_called_once()

@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
@@ -94,3 +118,52 @@ def _save_restore_fsdp() -> None:
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp_with_id(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp_with_id,
)

@staticmethod
def _save_restore_fsdp_with_id() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

dcp_cb = DistributedCheckpointSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
)
temp_dir = dcp_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[dcp_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_train_step_10")
dcp_cb.restore_with_id(checkpoint_id=ckpt_path, unit=my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory
49 changes: 25 additions & 24 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -230,39 +230,19 @@ def restore(
) -> None:
"""Utility method to restore dcp checkpoint from a path."""

-        # use gloo pg if available
-        gloo_pg_created = False
-        if dist.is_initialized():
-            pg = dist.group.WORLD if process_group is None else process_group
-
-            if dist.get_backend(pg) != dist.Backend.GLOO:
-                rank_zero_info(
-                    "Creating new gloo process group for loading checkpoint."
-                )
-                pg = dist.new_group(
-                    timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
-                )
-                gloo_pg_created = True
-        else:
-            pg = process_group

checkpoint_id = path

DistributedCheckpointSaver.restore_with_id(
checkpoint_id,
unit,
train_dataloader=train_dataloader,
-            process_group=pg,
+            process_group=process_group,
restore_options=restore_options,
knob_options=knob_options,
planner=planner,
storage_reader=storage_reader,
)

-        # destroy gloo pg if created, its sole purpose was for checkpoint restore
-        if gloo_pg_created:
-            dist.destroy_process_group(pg)

@staticmethod
def restore_with_id(
checkpoint_id: Union[int, str],
@@ -284,15 +264,32 @@ def restore_with_id(
checkpoint_id: Checkpoint id. It can be the path of the snapshot to restore.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
-            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) Note:
-                If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
+            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+                If not Gloo, a Gloo process group is created.
+                Note: If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
restore_options: Controls what to filter when restoring the state.
knob_options: Additional keyword options for StorageWriter and StorageReader
            planner: Instance of LoadPlanner. If this is not specified, the default planner will be used. (Default: ``None``)
storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer
the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``)
"""

+        # use gloo pg if available
+        gloo_pg_created = False
+        if dist.is_initialized():
+            pg = dist.group.WORLD if process_group is None else process_group
+
+            if dist.get_backend(pg) != dist.Backend.GLOO:
+                rank_zero_info(
+                    "Creating new gloo process group for loading checkpoint."
+                )
+                pg = dist.new_group(
+                    timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
+                )
+                gloo_pg_created = True
+        else:
+            pg = process_group

restore_options = restore_options or RestoreOptions()
app_state = _prepare_app_state_for_restore(unit, restore_options)
checkpoint_id = str(checkpoint_id)
@@ -340,13 +337,17 @@ def restore_with_id(
checkpoint_id=checkpoint_id,
storage_reader=storage_reader,
planner=planner,
-            process_group=process_group,
+            process_group=pg,
)

rank_zero_info(
f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
)

+        # destroy gloo pg if created, its sole purpose was for checkpoint restore
+        if gloo_pg_created:
+            dist.destroy_process_group(pg)

def _generate_checkpoint_and_upkeep(
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
) -> bool:

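For reference, a hedged usage sketch of the updated restore_with_id entry point: the unit class, module shape, and checkpoint path below are placeholders, and a default process group (e.g. NCCL launched via torchrun) is assumed. With this change, the temporary Gloo group used for loading is created and destroyed inside restore_with_id itself, so restore inherits the same behavior.

import torch
import torch.distributed as dist
import torch.nn as nn

from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit


class MyUnit(TrainUnit[torch.Tensor]):
    # Placeholder unit: the module attribute is expected to be tracked and restored.
    def __init__(self) -> None:
        super().__init__()
        self.module = nn.Linear(2, 2)

    def train_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
        return self.module(data)


if __name__ == "__main__":
    # Assumes a torchrun-style launch so env:// initialization works.
    dist.init_process_group(backend="nccl")
    unit = MyUnit()
    # When the default backend is not Gloo, restore_with_id now builds a
    # temporary Gloo group for dcp.load and destroys it once loading finishes.
    DistributedCheckpointSaver.restore_with_id(
        checkpoint_id="path/to/snapshot",  # placeholder checkpoint location
        unit=unit,
    )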