Require dist initialized if world size > 1 in checkpointing (#930)
Summary: Pull Request resolved: #930

Differential Revision: D64342419
JKSenthil authored and facebook-github-bot committed Oct 15, 2024
1 parent 716243c commit 10b0c68
Showing 5 changed files with 32 additions and 2 deletions.
14 changes: 14 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -122,6 +122,9 @@ class BaseCheckpointerTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def tearDown(self) -> None:
os.environ.pop('WORLD_SIZE', None)

def test_save_every_n_train_steps(self) -> None:
input_dim = 2
dataset_len = 10
@@ -1199,6 +1202,17 @@ def _test_directory_path_synced() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

def test_dist_not_initialized(self) -> None:
"""
Tests that BaseCheckpointSaver cannot be initialized without dist being initialized
if world size > 1
"""
os.environ["WORLD_SIZE"] = "2"
with self.assertRaisesRegex(
RuntimeError, "Running in a distributed environment"
):
BaseCheckpointSaver("foo")


class MyValLossUnit(TrainUnit[Batch]):
def __init__(self) -> None:
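For context on the new test above: torchtnt's get_world_size() helper is assumed to fall back to the WORLD_SIZE environment variable when no process group is initialized, which is why the test sets WORLD_SIZE=2 and the tearDown hooks pop it again. A minimal sketch of that assumption, runnable in a single local process:

import os

import torch.distributed as dist
from torchtnt.utils.distributed import get_world_size

os.environ["WORLD_SIZE"] = "2"      # simulate a 2-process launch without init_process_group
assert not dist.is_initialized()
print(get_world_size())             # expected to print 2 via the env-var fallback
os.environ.pop("WORLD_SIZE", None)  # clean up, mirroring the tests' tearDown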
3 changes: 3 additions & 0 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -54,6 +54,9 @@


class DistributedCheckpointSaverTest(unittest.TestCase):
def tearDown(self) -> None:
os.environ.pop("WORLD_SIZE", None)

def test_save_restore(self) -> None:
input_dim = 2
dataset_len = 10
3 changes: 3 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -40,6 +40,9 @@


class TorchSnapshotSaverTest(unittest.TestCase):
def tearDown(self) -> None:
os.environ.pop("WORLD_SIZE", None)

def test_save_restore(self) -> None:
input_dim = 2
dataset_len = 10
12 changes: 11 additions & 1 deletion torchtnt/framework/callbacks/base_checkpointer.py
@@ -36,7 +36,7 @@
MetricData,
Phase,
)
from torchtnt.utils.distributed import PGWrapper
from torchtnt.utils.distributed import get_world_size, PGWrapper
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

logger: logging.Logger = logging.getLogger(__name__)
@@ -55,6 +55,9 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
This will be used by this base class to ensure the integrity of the checkpoint. This is a list because some checkpointers may allow more than one valid
``metadata_fnames``, depending on storage or optimization configurations.
If running in a distributed environment, the default process group should be initialized prior to instantiating this Callback. This is done automatically if
using `AutoUnit`, which should be instantiated first.
Args:
dirpath: Parent directory to save checkpoints to.
save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
@@ -96,6 +99,13 @@ def __init__(
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
) -> None:
if get_world_size() > 1 and not dist.is_initialized():
raise RuntimeError(
"Running in a distributed environment without default process group initialized. "
"Call `torch.distributed.init_process_group` before initializing this callback. "
"Using `AutoUnit` will do this automatically."
)

if save_every_n_train_steps is not None and save_every_n_train_steps <= 0:
raise ValueError(
f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or positive number, but received {save_every_n_train_steps}"
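The docstring and constructor check added above require the default process group to exist before any BaseCheckpointer subclass is built when more than one process is involved. A hedged usage sketch of that ordering, using DistributedCheckpointSaver as the concrete callback (the backend choice and directory path are illustrative):

import torch.distributed as dist
from torchtnt.framework.callbacks import DistributedCheckpointSaver

# Under a multi-process launcher (e.g. torchrun), initialize the default
# process group first -- per the docstring, AutoUnit does this step automatically.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo")

saver = DistributedCheckpointSaver(
    dirpath="/tmp/checkpoints",    # illustrative path
    save_every_n_train_steps=100,
)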
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
@@ -95,7 +95,7 @@ class DistributedCheckpointSaver(BaseCheckpointer):
knob_options: Additional keyword options for StorageWriter. <https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.StorageWriter/>
Note:
If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
If torch.distributed is available, a process group should be initialized. In this case, DCP assumes the intention is to save/load checkpoints in a distributed fashion.
Additionally, a gloo process group must be initialized for async_checkpoint. For workloads that require nccl, the recommended initialization is 'cpu:gloo,cuda:nccl'
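The note above recommends the composite 'cpu:gloo,cuda:nccl' backend so that async checkpointing has a gloo group available alongside nccl. A minimal sketch, assuming a standard torchrun launch that provides RANK and WORLD_SIZE:

import torch.distributed as dist

# Registers gloo for CPU tensors and nccl for CUDA tensors in one default group,
# covering both nccl training collectives and the gloo requirement of async_checkpoint.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")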
