def test_pickling(self) -> None:
    """Check that a ``CheckpointPath`` round-trips through ``pickle``.

    For each sample path this verifies that the serialized payload embeds
    the path string itself, and that unpickling yields an object equal to
    the original.
    """
    sample_paths = (
        "foo/epoch_0_step_1",
        "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
    )
    for path in sample_paths:
        # subTest keeps the cases independent: a failure on one path is
        # reported with its parameters and does not hide the other.
        with self.subTest(path=path):
            ckpt = CheckpointPath.from_str(path)

            pickled = pickle.dumps(ckpt)

            # Don't test payload equality because of the custom protocol;
            # only check that the path string is embedded in the raw bytes.
            # assertIn (instead of assertTrue(... in ...)) prints both
            # operands when the check fails.
            self.assertIn(path.encode("utf-8"), pickled)

            unpickled = pickle.loads(pickled)
            self.assertEqual(unpickled, ckpt)
# pyre-strict

import os
import shutil
import tempfile
import unittest

import torch.distributed as dist
from torchtnt.utils import init_from_env
from torchtnt.utils.checkpoint import get_checkpoint_dirpaths
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TestCheckpointUtilsGPU(unittest.TestCase):
    """Multi-process GPU tests for the checkpoint directory utilities."""

    @skip_if_not_distributed
    @skip_if_not_gpu
    def test_get_checkpoint_dirpaths_distributed(self) -> None:
        """Run the checkpoint-dirpath check on 2 processes with the NCCL backend."""
        spawn_multi_process(
            2,
            "nccl",
            self._test_get_checkpoint_dirpaths,
        )

    @staticmethod
    def _test_get_checkpoint_dirpaths() -> None:
        """
        Tests retrieving checkpoint directories from a given root directory
        using NCCL on GPUs with custom state for pickling.

        Rank 0 creates a temporary directory holding one sub-directory per
        checkpoint name; the other ranks pass ``None`` as the root. Every
        rank is then expected to get back the full set of checkpoint paths.
        """
        init_from_env()
        paths = [
            "epoch_0_step_10",
            "epoch_1_step_10_val_loss=10.5",
            "epoch_2_step_10",
            "epoch_0_step_5",
            "epoch_0_step_6_acc=0.03",
            "epoch_0_step_3",
        ]

        rank = get_global_rank()
        tc = unittest.TestCase()

        # Only rank 0 will know about the directory; it stays None elsewhere.
        local_root = tempfile.mkdtemp() if rank == 0 else None

        try:
            if local_root is not None:
                for path in paths:
                    os.mkdir(os.path.join(local_root, path))

            ckpt_dirpaths = get_checkpoint_dirpaths(
                local_root, process_group=dist.group.WORLD
            )

            # Broadcast rank 0's directory so every rank can build the
            # expected set of paths and verify successful execution.
            shared = [local_root]
            dist.broadcast_object_list(shared, src=0, group=dist.group.WORLD)
            expected_root = shared[0]
            tc.assertIsNotNone(expected_root)

            tc.assertEqual(
                {str(x) for x in ckpt_dirpaths},
                {os.path.join(expected_root, path) for path in paths},
            )
        finally:
            # Clean up even when an assertion above fails, so a failing run
            # does not leave the temporary tree behind on rank 0.
            if local_root is not None:
                shutil.rmtree(local_root)
def __gt__(self, other: "CheckpointPath") -> bool:
    """Return whether this checkpoint is newer than ``other`` (delegates to ``newer_than``)."""
    return self.newer_than(other)


def __getstate__(self) -> str:
    """Return the state to pickle: only the checkpoint's path string.

    Lightweight pickling to avoid broadcast errors.
    """
    return self.path


def __setstate__(self, state: str) -> None:
    """Restore the instance's fields from a pickled path string.

    Args:
        state: the path string previously returned by ``__getstate__``.

    Raises:
        ValueError: if ``state`` does not match ``PATH_REGEX``.
    """
    # Match regex directly to avoid creating a new instance with `from_str`
    path_match = self.PATH_REGEX.match(state)
    if path_match is None:
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let a malformed state fall through
        # to an opaque AttributeError on `None.groups()` below.
        raise ValueError(f"Malformed checkpoint found when unpickling: {state}")

    dirpath, epoch, step, metric_name, metric_value = path_match.groups()
    self.dirpath = dirpath.rstrip("/")
    self.epoch = int(epoch)
    self.step = int(step)
    # The metric groups are only set when both a name and a value were
    # captured from the path.
    self.metric_data = (
        MetricData(name=metric_name, value=float(metric_value))
        if metric_name and metric_value
        else None
    )