create gloo pg for DCPSaver.restore() (#874)
Summary: Pull Request resolved: #874

Reviewed By: galrotem

Differential Revision: D60408282

fbshipit-source-id: a0aaf117203ed6dd1f5c6e79a955a1bd8f855821
JKSenthil authored and facebook-github-bot committed Aug 1, 2024
1 parent 5c73bd5 commit 544a225
Showing 3 changed files with 57 additions and 1 deletion.
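
For context, a minimal usage sketch of the changed path, not part of the commit: it assumes a multi-process NCCL setup and a placeholder checkpoint path, and mirrors the GPU test below. With this change, restore() creates a temporary gloo process group for the load when the default backend is not gloo, and destroys it afterwards.

# Hypothetical usage sketch (assumes a multi-process NCCL setup and a
# placeholder checkpoint path); restore() now builds a temporary gloo
# process group for the load and tears it down once loading finishes.
import torch.distributed as dist
from torch import nn

from torchtnt.framework._test_utils import DummyAutoUnit
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

dist.init_process_group(backend="nccl")  # assumed to be set up by the launcher

unit = DummyAutoUnit(module=nn.Linear(2, 3))
# A gloo process group is created internally just for this load, then destroyed.
DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=unit)
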
11 changes: 11 additions & 0 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -455,6 +455,17 @@ def test_restore_allow_partial_loading(self, mock_dist_cp: MagicMock) -> None:
].allow_partial_load
self.assertFalse(allow_partial_load)

@patch("torch.distributed.destroy_process_group")
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def test_gloo_pg_restore(
self, mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
) -> None:
my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
self.assertEqual(process_group, None)
mock_destroy_process_group.assert_not_called()


class DummyStatefulDataLoader:
def __init__(self, dataloader: DataLoader) -> None:
24 changes: 24 additions & 0 deletions tests/framework/callbacks/test_dcp_saver_gpu.py
@@ -11,8 +11,10 @@
import shutil
import tempfile
import unittest
from unittest.mock import MagicMock, patch

import torch
from torch import distributed as dist, nn

from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
@@ -22,6 +24,28 @@


class DistributedCheckpointSaverGPUTest(unittest.TestCase):
@skip_if_not_distributed
@skip_if_not_gpu
def test_test_gloo_pg_restore(self) -> None:
spawn_multi_process(
1,
"nccl",
self._test_gloo_pg_restore,
)

@staticmethod
@patch("torch.distributed.destroy_process_group")
@patch("torchtnt.framework.callbacks.dcp_saver.dcp")
def _test_gloo_pg_restore(
mock_dist_cp: MagicMock, mock_destroy_process_group: MagicMock
) -> None:
tc = unittest.TestCase()
my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
DistributedCheckpointSaver.restore(path="path/to/snapshot", unit=my_unit)
process_group = mock_dist_cp.load.call_args.kwargs["process_group"]
tc.assertEqual(dist.get_backend(process_group), dist.Backend.GLOO, None)
mock_destroy_process_group.assert_called_once()

@skip_if_not_distributed
@skip_if_not_gpu
def test_save_restore_fsdp(self) -> None:
23 changes: 22 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
@@ -9,6 +9,7 @@
import logging
import time
from concurrent.futures import Future
from datetime import timedelta
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
@@ -273,19 +274,39 @@ def restore(
) -> None:
"""Utility method to restore dcp checkpoint from a path."""

# use gloo pg if available
gloo_pg_created = False
if dist.is_initialized():
pg = dist.group.WORLD if process_group is None else process_group

if dist.get_backend(pg) != dist.Backend.GLOO:
rank_zero_info(
"Creating new gloo process group for loading checkpoint."
)
pg = dist.new_group(
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
)
gloo_pg_created = True
else:
pg = process_group

checkpoint_id = path

DistributedCheckpointSaver.restore_with_id(
checkpoint_id,
unit,
train_dataloader=train_dataloader,
process_group=process_group,
process_group=pg,
restore_options=restore_options,
knob_options=knob_options,
planner=planner,
storage_reader=storage_reader,
)

# destroy gloo pg if created, its sole purpose was for checkpoint restore
if gloo_pg_created:
dist.destroy_process_group(pg)

@staticmethod
def restore_with_id(
checkpoint_id: Union[int, str],
