diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py
index e43ad50531..aecc3540ce 100644
--- a/tests/framework/callbacks/test_base_checkpointer.py
+++ b/tests/framework/callbacks/test_base_checkpointer.py
@@ -122,6 +122,9 @@ class BaseCheckpointerTest(unittest.TestCase):
     cuda_available: bool = torch.cuda.is_available()
     distributed_available: bool = torch.distributed.is_available()
 
+    def tearDown(self) -> None:
+        os.environ.pop("WORLD_SIZE", None)
+
     def test_save_every_n_train_steps(self) -> None:
         input_dim = 2
         dataset_len = 10
@@ -1199,6 +1202,17 @@ def _test_directory_path_synced() -> None:
         if get_global_rank() == 0:
             shutil.rmtree(temp_dir)  # delete temp directory
 
+    def test_dist_not_initialized(self) -> None:
+        """
+        Tests that BaseCheckpointSaver cannot be initialized with world size > 1
+        unless the distributed process group is initialized.
+        """
+        os.environ["WORLD_SIZE"] = "2"
+        with self.assertRaisesRegex(
+            RuntimeError, "Running in a distributed environment"
+        ):
+            BaseCheckpointSaver("foo")
+
 
 class MyValLossUnit(TrainUnit[Batch]):
     def __init__(self) -> None:
diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index 7b8fe71414..62f73a29fc 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -54,6 +54,9 @@ class DistributedCheckpointSaverTest(unittest.TestCase):
+    def tearDown(self) -> None:
+        os.environ.pop("WORLD_SIZE", None)
+
     def test_save_restore(self) -> None:
         input_dim = 2
         dataset_len = 10
diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index ddf617bb89..b47f3db2f8 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -40,6 +40,9 @@ class TorchSnapshotSaverTest(unittest.TestCase):
+    def tearDown(self) -> None:
+        os.environ.pop("WORLD_SIZE", None)
+
     def test_save_restore(self) -> None:
         input_dim = 2
         dataset_len = 10
diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index ddf02c420e..81a065ccf9 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -36,7 +36,7 @@
     MetricData,
     Phase,
 )
-from torchtnt.utils.distributed import PGWrapper
+from torchtnt.utils.distributed import get_world_size, PGWrapper
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -55,6 +55,9 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
     This will be used by this base class to ensure the integrity of the checkpoint. This is a list because some checkpointers may allow more than one valid ``metadata_fnames``, depending
     on storage or optimization configurations.
 
+    If running in a distributed environment, the default process group should be initialized prior to instantiating this callback. This is done automatically
+    when using `AutoUnit`, provided the `AutoUnit` is instantiated first.
+
     Args:
         dirpath: Parent directory to save checkpoints to.
         save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
@@ -96,6 +99,13 @@ def __init__(
         best_checkpoint_config: Optional[BestCheckpointConfig] = None,
         process_group: Optional[dist.ProcessGroup] = None,
     ) -> None:
+        if get_world_size() > 1 and not dist.is_initialized():
+            raise RuntimeError(
+                "Running in a distributed environment without the default process group initialized. "
+                "Call `torch.distributed.init_process_group` before initializing this callback. "
+                "Using `AutoUnit` will do this automatically."
+            )
+
         if save_every_n_train_steps is not None and save_every_n_train_steps <= 0:
             raise ValueError(
                 f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or positive number, but received {save_every_n_train_steps}"
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index f4c845d1eb..cdef922958 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -95,7 +95,7 @@ class DistributedCheckpointSaver(BaseCheckpointer):
         knob_options: Additional keyword options for StorageWriter.
 
     Note:
-        If torch.distributed is available and a process group is initialized, dcp assumes the intention is to save/load checkpoints in distributed fashion.
+        If torch.distributed is available, a process group should be initialized. In that case, DCP assumes the intention is to save/load checkpoints in a distributed fashion.
         Additionally, a gloo process group must be initialized for async_checkpoint. For workloads that require nccl, the recommended initialization is 'cpu:gloo,cuda:nccl'
 
     Note:
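
Reviewer note: a minimal usage sketch (not part of the patch) of what the new guard in BaseCheckpointer.__init__ expects when WORLD_SIZE > 1: the default process group must exist before the callback is constructed. The backend string follows the 'cpu:gloo,cuda:nccl' recommendation quoted in the DistributedCheckpointSaver docstring; the checkpoint directory and save frequency below are illustrative placeholders.

# Hedged sketch, assuming a torchrun-style launch where WORLD_SIZE, RANK,
# MASTER_ADDR and MASTER_PORT are set in the environment.
import torch.distributed as dist

from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver

# With WORLD_SIZE > 1, the default process group must be initialized before the
# callback is constructed; otherwise __init__ now raises a RuntimeError.
# 'cpu:gloo,cuda:nccl' keeps a gloo group available for async_checkpoint.
if not dist.is_initialized():
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")

saver = DistributedCheckpointSaver("/tmp/checkpoints", save_every_n_train_steps=100)

Alternatively, constructing an `AutoUnit` first initializes the process group automatically, which is why the error message points users there.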