From 82b5c83b07d2321049852dd32867d6f8fcca2972 Mon Sep 17 00:00:00 2001
From: Diego Urgell
Date: Wed, 1 May 2024 10:03:09 -0700
Subject: [PATCH] Add CheckpointManager abstraction in utils/checkpoint.py

Differential Revision: D56427226
---
 tests/utils/test_checkpoint.py | 358 +++++++++++++++++++++++++++++++++
 torchtnt/utils/__init__.py     |   2 +
 torchtnt/utils/checkpoint.py   | 265 +++++++++++++++++++++++-
 3 files changed, 616 insertions(+), 9 deletions(-)

diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
index 11797d77a0..fd61afaadd 100644
--- a/tests/utils/test_checkpoint.py
+++ b/tests/utils/test_checkpoint.py
@@ -10,6 +10,7 @@
 import shutil
 import tempfile
 import unittest
+from unittest.mock import patch

 import torch

@@ -17,6 +18,9 @@
 from torch import nn
 from torchsnapshot import Snapshot
 from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
+from torchtnt.framework._test_utils import Batch
+from torchtnt.framework.state import State
+from torchtnt.framework.unit import TrainUnit

 from torchtnt.utils import get_global_rank, init_from_env
 from torchtnt.utils.checkpoint import (
@@ -25,6 +29,8 @@
     _retrieve_checkpoint_dirpaths,
     _sort_by_metric_value,
     _sort_by_recency,
+    BestCheckpointConfig,
+    CheckpointManager,
     CheckpointPath,
     get_best_checkpoint_path,
     get_checkpoint_dirpaths,
@@ -190,6 +196,349 @@ def test_pickling(self) -> None:
         self.assertEqual(unpickled, ckpt)


+class CheckpointManagerTest(unittest.TestCase):
+    def test_create_checkpoint_manager(self) -> None:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            paths = [
+                f"{temp_dir}/epoch_1_step_3",
+                f"{temp_dir}/epoch_0_step_1",
+                f"{temp_dir}/epoch_0_step_5_loss=-0.3",
+                f"{temp_dir}/epoch_1_step_1",
+                f"{temp_dir}/epoch_1_step_2_loss=0.5",
+                f"{temp_dir}/epoch_2_step_5_loss=0.3",
+                f"{temp_dir}/epoch_0_step_2_acc=0.7",
+            ]
+            for path in paths:
+                os.mkdir(path)
+
+            # without last_n_checkpoints
+            ckpt_manager = CheckpointManager(temp_dir)
+            self.assertEqual(ckpt_manager._ckpt_paths, [])
+
+            # with last_n_checkpoints but without metric
+            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
+            self.assertEqual(
+                [x.path for x in ckpt_manager._ckpt_paths],
+                [
+                    f"{temp_dir}/epoch_0_step_1",
+                    f"{temp_dir}/epoch_0_step_2_acc=0.7",
+                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
+                    f"{temp_dir}/epoch_1_step_1",
+                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
+                    f"{temp_dir}/epoch_1_step_3",
+                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
+                ],
+            )
+
+            # with last_n_checkpoints and metric min
+            ckpt_manager = CheckpointManager(
+                temp_dir,
+                keep_last_n_checkpoints=3,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="loss", mode="min"
+                ),
+            )
+            self.assertEqual(
+                [x.path for x in ckpt_manager._ckpt_paths],
+                [
+                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
+                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
+                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
+                ],
+            )
+
+            # with last_n_checkpoints and metric max
+            ckpt_manager = CheckpointManager(
+                temp_dir,
+                keep_last_n_checkpoints=3,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="loss", mode="max"
+                ),
+            )
+            self.assertEqual(
+                [x.path for x in ckpt_manager._ckpt_paths],
+                [
+                    f"{temp_dir}/epoch_0_step_5_loss=-0.3",
+                    f"{temp_dir}/epoch_2_step_5_loss=0.3",
+                    f"{temp_dir}/epoch_1_step_2_loss=0.5",
+                ],
+            )
+
+            # with last_n_checkpoints and non previously tracked metric
+            ckpt_manager = CheckpointManager(
+                temp_dir,
+                keep_last_n_checkpoints=3,
+                best_checkpoint_config=BestCheckpointConfig(
+                    monitored_metric="foo", mode="max"
+                ),
+            )
+            self.assertEqual(ckpt_manager._ckpt_paths, [])
+
+    @skip_if_not_distributed
+    def test_create_checkpoint_manager_distributed(self) -> None:
+        spawn_multi_process(
+            2,
+            "gloo",
+            self._test_create_checkpoint_manager_distributed,
+        )
+
+    @staticmethod
+    def _test_create_checkpoint_manager_distributed() -> None:
+        if get_global_rank() == 0:
+            temp_dir = tempfile.mkdtemp()
+            paths = ["epoch_1_step_2", "epoch_0_step_1", "epoch_1_step_1"]
+            for path in paths:
+                os.mkdir(os.path.join(temp_dir, path))
+        else:
+            temp_dir = ""
+
+        tc = unittest.TestCase()
+
+        # without top k config
+        ckpt_manager = CheckpointManager(temp_dir)
+        tc.assertNotEqual(ckpt_manager.dirpath, "")
+        tc.assertEqual(ckpt_manager._ckpt_paths, [])
+
+        # with top k config
+        ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
+        tc.assertNotEqual(ckpt_manager.dirpath, "")
+        tc.assertEqual(
+            [str(x) for x in ckpt_manager._ckpt_paths],
+            [
+                os.path.join(ckpt_manager.dirpath, path)
+                for path in [
+                    "epoch_0_step_1",
+                    "epoch_1_step_1",
+                    "epoch_1_step_2",
+                ]
+            ],
+        )
+
+    def test_prune_surplus_checkpoints(self) -> None:
+        # with checkpoints to delete
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
+            paths = [
+                CheckpointPath(temp_dir, 0, 0),
+                CheckpointPath(temp_dir, 0, 1),
+                CheckpointPath(temp_dir, 1, 0),
+            ]
+            for path in paths:
+                os.mkdir(path.path)
+
+            ckpt_manager._ckpt_paths = list(paths)
+            warning_messages = []
+            expected_warning_msg = (
+                f"3 checkpoints found in {temp_dir}. ",
+                f"Deleting {2} oldest ",
+                "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
+            )
+            with patch(
+                f"{CheckpointManager.__module__}.logging.Logger.warning",
+                warning_messages.append,
+            ):
+                ckpt_manager.prune_surplus_checkpoints()
+
+            self.assertEqual(warning_messages[0], expected_warning_msg)
+            self.assertEqual(ckpt_manager._ckpt_paths, [paths[2]])
+            self.assertTrue(os.path.exists(paths[2].path))
+            self.assertFalse(os.path.exists(paths[0].path))
+            self.assertFalse(os.path.exists(paths[1].path))
+
+        # without checkpoints to delete
+        with tempfile.TemporaryDirectory() as temp_dir:
+            ckpt_manager = CheckpointManager(temp_dir)
+            paths = [
+                CheckpointPath(temp_dir, 0, 0),
+                CheckpointPath(temp_dir, 0, 1),
+                CheckpointPath(temp_dir, 1, 0),
+            ]
+            ckpt_manager._ckpt_paths = list(paths)
+            ckpt_manager.prune_surplus_checkpoints()
+            self.assertEqual(ckpt_manager._ckpt_paths, paths)
+
+    def test_generate_checkpoint_path(self) -> None:
+        ckpt_manager = CheckpointManager("foo")
+
+        self.assertEqual(
+            ckpt_manager.generate_checkpoint_path(1, 1).path,
+            "foo/epoch_1_step_1",
+        )
+
+        self.assertEqual(
+            ckpt_manager.generate_checkpoint_path(1, 3).path,
+            "foo/epoch_1_step_3",
+        )
+
+        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
+            monitored_metric="val_loss", mode="min"
+        )
+        self.assertEqual(
+            ckpt_manager.generate_checkpoint_path(
+                1, 3, MetricData("val_loss", 0.5)
+            ).path,
+            "foo/epoch_1_step_3_val_loss=0.5",
+        )
+
+        # best checkpoint config, but did not pass metric data - expect path but no metric
+        self.assertEqual(
+            ckpt_manager.generate_checkpoint_path(1, 2).path,
+            "foo/epoch_1_step_2",
+        )
+
+        # passed metric data is tracking a different metric than best checkpoint config - expect exception
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Attempted to get a checkpoint with metric 'mean', but best checkpoint config is for 'val_loss'",
+        ):
+            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("mean", 3.5))
+
+        # no best checkpoint config, but passed metric data - expect exception
+        ckpt_manager._best_checkpoint_config = None
+        with self.assertRaisesRegex(
+            AssertionError,
+            "Attempted to get a checkpoint with metric but best checkpoint config is not set",
+        ):
+            ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5))
+
+    def test_append_checkpoint_by_recency(self) -> None:
+        ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2)
+        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)]
+
+        # without need to remove old by recency
+        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 1))
+        self.assertEqual(
+            ckpt_manager._ckpt_paths,
+            [CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)],
+        )
+
+        # removing old by recency
+        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
+            ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2))
+            self.assertEqual(
+                ckpt_manager._ckpt_paths,
+                [CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)],
+            )
+            mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True)
+
+    def test_append_checkpoint_by_metric(self) -> None:
+        ckpt_manager = CheckpointManager(
+            "foo",
+            keep_last_n_checkpoints=5,
+            best_checkpoint_config=BestCheckpointConfig(
+                monitored_metric="val_loss", mode="max"
+            ),
+        )
+        paths = [
+            CheckpointPath(
+                "foo", 0, x, metric_data=MetricData(name="val_loss", value=0.01 * x)
+            )
+            for x in range(1, 7, 1)
+        ]
+        ckpt_manager._ckpt_paths = [paths[1], paths[2], paths[4]]
+        # without need to remove old by min metric, goes beginning
+        ckpt_manager.append_checkpoint(paths[0])
+        self.assertEqual(
+            ckpt_manager._ckpt_paths,
+            [paths[0], paths[1], paths[2], paths[4]],
+        )
+        # without need to remove old by min metric, goes end
+        ckpt_manager.append_checkpoint(paths[5])
+        self.assertEqual(
+            ckpt_manager._ckpt_paths,
+            [paths[0], paths[1], paths[2], paths[4], paths[5]],
+        )
+        # removing old max metric, goes middle
+        with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
+            ckpt_manager.append_checkpoint(paths[3])
+            self.assertEqual(
+                ckpt_manager._ckpt_paths,
+                [paths[1], paths[2], paths[3], paths[4], paths[5]],
+            )
+            mock_rm.assert_called_once_with(
+                "foo/epoch_0_step_1_val_loss=0.01", recursive=True
+            )
+
+        # no metric data - noop
+        ckpt_manager._keep_last_n_checkpoints = None
+        ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 8))
+        self.assertEqual(
+            ckpt_manager._ckpt_paths,
+            [paths[1], paths[2], paths[3], paths[4], paths[5]],
+        )
+
+    def test_should_save_checkpoint(self) -> None:
+        """
+        Tests basic functionality of should_save_checkpoint
+        """
+        ckpt_manager = CheckpointManager("foo")
+
+        # test default behavior
+        ckpt = CheckpointPath("foo", 0, 2)
+        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
+
+        ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 1)]
+        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
+        ckpt_manager._keep_last_n_checkpoints = 1
+        self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
+
+        ckpt_manager._ckpt_paths = [
+            CheckpointPath(
+                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
+            ),
+        ]
+        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
+            monitored_metric="val_loss",
+            mode="min",
+        )
+
+        bigger_metric = CheckpointPath(
+            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.02)
+        )
+        smaller_metric = CheckpointPath(
+            "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.001)
+        )
+        ckpt_manager._keep_last_n_checkpoints = None
+        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
+        ckpt_manager._keep_last_n_checkpoints = 1
+        self.assertFalse(ckpt_manager.should_save_checkpoint(bigger_metric))
+        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
+        ckpt_manager._keep_last_n_checkpoints = 2
+        self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
+        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
+
+        # Make sure we are actually comparing against the more optimal element
+        ckpt_manager._ckpt_paths = [
+            CheckpointPath(
+                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
+            ),
+            CheckpointPath(
+                "foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.05)
+            ),
+        ]
+
+        ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
+            monitored_metric="val_loss",
+            mode="max",
+        )
+        ckpt_manager._keep_last_n_checkpoints = 2
+        self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
+
+    def test_remove_worst_checkpoint(self) -> None:
+        with tempfile.TemporaryDirectory() as temp_dir:
+            os.mkdir(os.path.join(temp_dir, "epoch_0_step_0"))
+            os.mkdir(os.path.join(temp_dir, "epoch_0_step_1"))
+
+            ckpt_manager = CheckpointManager(temp_dir)
+            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 0))
+            ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 1))
+
+            ckpt_manager.remove_checkpoint()
+            self.assertFalse(os.path.exists(os.path.join(temp_dir, "epoch_0_step_0")))
+            self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
+            self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
+
+
 class CheckpointUtilsTest(unittest.TestCase):
     @staticmethod
     def _create_snapshot_metadata(output_dir: str) -> None:
@@ -590,3 +939,12 @@ def test_metadata_exists(self) -> None:

         os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
         self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))
+
+
+class MyValLossUnit(TrainUnit[Batch]):
+    def __init__(self) -> None:
+        super().__init__()
+        self.val_loss = 0.01
+
+    def train_step(self, state: State, data: Batch) -> None:
+        return None
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index eac6b113e4..397ec67d26 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -8,6 +8,7 @@

 from .checkpoint import (
     BestCheckpointConfig,
+    CheckpointManager,
     CheckpointPath,
     get_best_checkpoint_path,
     get_checkpoint_dirpaths,
@@ -93,6 +94,7 @@
     "get_checkpoint_dirpaths",
     "get_latest_checkpoint_path",
     "BestCheckpointConfig",
+    "CheckpointManager",
     "copy_data_to_device",
     "CPUStats",
     "get_device_from_env",
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
index 5ef013cde0..a8751a8b84 100644
--- a/torchtnt/utils/checkpoint.py
+++ b/torchtnt/utils/checkpoint.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 # pyre-strict
+import bisect
 import logging
 import os
 import re
@@ -16,7 +17,7 @@
 import torch.distributed as dist
 from fsspec.core import url_to_fs
 from pyre_extensions import none_throws
-from torchtnt.utils.distributed import rank_zero_read_and_broadcast
+from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast

 logger: logging.Logger = logging.getLogger(__name__)

@@ -218,6 +219,240 @@ def __setstate__(self, state: str) -> None:
         )


+class CheckpointManager:
+    """
+    Manage a group of CheckpointPaths that belong to the same base directory. This involves maintaining
+    an ordering of the checkpoints by recency or metric value, if applicable.
+    This ordering is then used to determine whether a checkpoint should be saved, and what name it will be given.
+
+    Checkpoint directories are named using the following format: <dirpath>/epoch_<epoch>_step_<step>
+    If a metric is being tracked, it is appended to the name: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>
+
+    The methods in this class are meant to be used in the following order:
+    1. Create an instance - this will load the existing checkpoints (if any)
+    2. `prune_surplus_checkpoints` - this will remove non-optimal checkpoints to enforce the `keep_last_n_checkpoints` limit
+    3. For every checkpointing iteration:
+        a. `generate_checkpoint_path`: Gives the CheckpointPath that would be saved next
+        b. `should_save_checkpoint`: Determines if the checkpoint should be saved according to `keep_last_n_checkpoints` and `best_checkpoint_config`
+        c. -- The external checkpointing API should be called if the above returns True. CheckpointManager does NOT actually generate checkpoints --
+        d. `append_checkpoint`: If the checkpoint was successfully saved, this should be called to update the internal state
+
+    In general, every file system read/write operation performed by this class is executed only on rank 0, while the resulting state is synced across ranks.
+    """
+
+    def __init__(
+        self,
+        dirpath: str,
+        best_checkpoint_config: Optional[BestCheckpointConfig] = None,
+        keep_last_n_checkpoints: Optional[int] = None,
+        metadata_fname: Optional[str] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> None:
+        """
+        Initialize a checkpoint manager. If a `keep_last_n_checkpoints` value is provided, this will read the
+        existing checkpoints in the dirpath (from rank 0 only) to account for them in the maximum number of
+        checkpoints to keep. Note that no checkpoints are deleted at this point.
+
+        Args:
+            dirpath: The base directory path that checkpoints are saved in. This is synced from rank 0 to every other rank upon initialization.
+            best_checkpoint_config: Optional configuration for the best checkpoint.
+            keep_last_n_checkpoints: Optional maximum number of checkpoints to keep.
+            metadata_fname: Optional name of the metadata file.
+            process_group: Optional process group to use for distributed training. gloo process groups are
+                recommended over nccl.
+        """
+        self.dirpath: str = self._sync_dirpath_to_all_ranks(
+            dirpath=dirpath, process_group=process_group
+        )
+
+        self._best_checkpoint_config = best_checkpoint_config
+        self._keep_last_n_checkpoints = keep_last_n_checkpoints
+        self._metadata_fname = metadata_fname
+        self._pg_wrapper = PGWrapper(process_group)
+
+        self._ckpt_paths: List[CheckpointPath] = []
+        if not self._keep_last_n_checkpoints:
+            return
+
+        # If there is a max limit of checkpoints to store, keep track of existing ones
+        metric_name = (
+            best_checkpoint_config.monitored_metric if best_checkpoint_config else None
+        )
+        self._ckpt_paths = get_checkpoint_dirpaths(
+            dirpath=dirpath,
+            metadata_fname=self._metadata_fname,
+            metric_name=metric_name,
+            process_group=process_group,
+        )
+        if best_checkpoint_config:
+            self._ckpt_paths.sort(
+                key=lambda x: x.metric_data.value,
+                # sort descending if min, placing the worst metric at the front of the list
+                reverse=(best_checkpoint_config.mode == "min"),
+            )
+        else:
+            self._ckpt_paths.sort()  # Checkpoint paths are well-ordered by recency
+
+    def prune_surplus_checkpoints(self) -> None:
+        """
+        Prune checkpoints that exceed the maximum number of checkpoints to keep. This should be
+        called when training starts, so that the `keep_last_n_checkpoints` config is honored.
+        Files are only deleted on rank 0.
+
+        Note:
+            This is not called on initialization, in case users want to inspect previous
+            checkpoints. But it should be called before starting training if there is a
+            `keep_last_n_checkpoints` config.
+        """
+        keep_last_n_checkpoints = self._keep_last_n_checkpoints
+        if keep_last_n_checkpoints and len(self._ckpt_paths) > keep_last_n_checkpoints:
+            logger.warning(
+                (
+                    f"{len(self._ckpt_paths)} checkpoints found in {self.dirpath}. ",
+                    f"Deleting {len(self._ckpt_paths) - keep_last_n_checkpoints} oldest ",
+                    "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
+                )
+            )
+            for _ in range(len(self._ckpt_paths) - keep_last_n_checkpoints):
+                self.remove_checkpoint()
+
+    def generate_checkpoint_path(
+        self, epoch: int, step: int, metric_data: Optional[MetricData] = None
+    ) -> CheckpointPath:
+        """
+        Given the current epoch, step, and possibly a metric_data value, determine the path
+        where the next checkpoint should be stored. This does not necessarily mean that the
+        checkpoint should be created. Instead, `should_save_checkpoint` has to be called to
+        determine that.
+
+        Args:
+            epoch: The current epoch number.
+            step: The current step number.
+            metric_data: Optional name and value of the metric being monitored.
+
+        Returns:
+            The path to the checkpoint to save.
+
+        Raises:
+            AssertionError if there is a mismatch in the tracked metric, for example:
+            - `metric_data` was provided but `best_checkpoint_config` is not set
+            - both are set, but they do not track the same metric
+        """
+
+        if metric_data:
+            assert (
+                self._best_checkpoint_config
+            ), "Attempted to get a checkpoint with metric but best checkpoint config is not set"
+
+            assert self._best_checkpoint_config.monitored_metric == metric_data.name, (
+                f"Attempted to get a checkpoint with metric '{metric_data.name}', "
+                f"but best checkpoint config is for '{none_throws(self._best_checkpoint_config).monitored_metric}'"
+            )
+
+        checkpoint_path = CheckpointPath(
+            self.dirpath, epoch, step, metric_data=metric_data
+        )
+
+        return checkpoint_path
+
+    def should_save_checkpoint(self, checkpoint: CheckpointPath) -> bool:
+        """
+        Given a candidate checkpoint, determine whether it should be saved when considering the
+        `keep_last_n_checkpoints` and `best_checkpoint_config` configs.
+
+        Args:
+            checkpoint: The CheckpointPath to be potentially saved, provided by `generate_checkpoint_path`.
+
+        Returns:
+            True if the checkpoint should be saved, otherwise False.
+        """
+
+        keep_last_n_checkpoints = self._keep_last_n_checkpoints
+        if not keep_last_n_checkpoints:
+            # always save the candidate checkpoint if no limit of checkpoints is specified
+            return True
+
+        if len(self._ckpt_paths) < keep_last_n_checkpoints:
+            # the limit of checkpoints has not been reached
+            return True
+
+        best_checkpoint_config = self._best_checkpoint_config
+        if not best_checkpoint_config:
+            # we always save the latest checkpoint
+            return True
+
+        # otherwise, we need to determine if we should overwrite the worst checkpoint
+        return checkpoint.more_optimal_than(
+            self._ckpt_paths[0], mode=best_checkpoint_config.mode
+        )
+
+    def append_checkpoint(self, ckpt: CheckpointPath) -> None:
+        """
+        Update the internal state to keep track of the checkpoint. This function should only be called
+        when a checkpoint whose path was returned from `generate_checkpoint_path` was successfully created.
+        If a previous checkpoint has to be removed to honor `keep_last_n_checkpoints`, it will be deleted
+        on rank 0.
+
+        Args:
+            ckpt: The checkpoint that was successfully saved.
+ """ + # Remove oldest checkpoint if needed + max_ckpts = self._keep_last_n_checkpoints + if max_ckpts and len(self._ckpt_paths) >= max_ckpts: + self.remove_checkpoint() + + # If we are monitoring a metric, but the checkpoint has no metric data, we don't track it + if self._best_checkpoint_config and ckpt.metric_data: + keys = [none_throws(c.metric_data).value for c in self._ckpt_paths] + if self._best_checkpoint_config.mode == "min": + keys.reverse() + + # Use bisect.bisect() to find the insertion point + idx = bisect.bisect(keys, none_throws(ckpt.metric_data).value) + if none_throws(self._best_checkpoint_config).mode == "min": + idx = len(self._ckpt_paths) - idx + self._ckpt_paths.insert(idx, ckpt) + + elif not self._best_checkpoint_config: + # No metric tracked, most recents goes last + self._ckpt_paths.append(ckpt) + + @rank_zero_read_and_broadcast + def does_checkpoint_exist( + self, ckpt: CheckpointPath, process_group: Optional[dist.ProcessGroup] = None + ) -> bool: + """ + Checking whether a checkpoint already exists by verifying whether the optional metadata file is present in the directory. + If the checkpointer doesn't have a metadata file, this function will always return False. Check is executed in rank 0, but + result is broadcasted to all ranks. + """ + metadata_fname = self._metadata_fname + if not metadata_fname: + return False + + fs, _ = url_to_fs(self.dirpath) + return _metadata_exists(fs, ckpt.path, metadata_fname) + + @staticmethod + @rank_zero_read_and_broadcast + def _sync_dirpath_to_all_ranks( + dirpath: str, process_group: Optional[dist.ProcessGroup] = None + ) -> str: + """Synchronize the dirpath across all ranks.""" + return dirpath + + def remove_checkpoint(self) -> None: + """ + Delete the weakest checkpoint both from the internal state and from the file system (rank 0). This means: + - If there is a `best_checkpoint_config`, then the checkpoint with the least optimal metric value + - If there is no `best_checkpoint_config`, then the oldest checkpoint + """ + worst_ckpt_path = self._ckpt_paths.pop(0) + if self._pg_wrapper.get_rank() == 0: + fs, _ = url_to_fs(self.dirpath) + fs.rm(worst_ckpt_path.path, recursive=True) + + @rank_zero_read_and_broadcast def get_latest_checkpoint_path( dirpath: str, @@ -235,14 +470,15 @@ def get_latest_checkpoint_path( Raises: AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. - Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. + Note: + When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks. + gloo process groups are recommended over nccl. """ candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) if not candidate_dirpaths: return None - # Iterate through all files and directories in the specified directory latest_checkpoint = candidate_dirpaths[0] for candidate in candidate_dirpaths[1:]: if candidate.newer_than(latest_checkpoint): @@ -260,7 +496,10 @@ def get_best_checkpoint_path( process_group: Optional[dist.ProcessGroup] = None, ) -> Optional[str]: """ - Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory. + Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory based on a metric. + + The checkpoint paths are assumed to have the following format: /epoch__step__= + This will always be the case if the CheckpointManager class is used to produce their names. 
Args: dirpath: parent directory where checkpoints are saved. @@ -269,12 +508,13 @@ def get_best_checkpoint_path( metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. + Note: + When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks. + gloo process groups are recommended over nccl. """ dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) if not dirpaths: - # no checkpoints found return None best_checkpoint = dirpaths[0] @@ -294,7 +534,11 @@ def get_checkpoint_dirpaths( ) -> List[CheckpointPath]: """ Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. - The order of the checkpoints is not guarenteed. + The order of the checkpoints is not guaranteed. + + The checkpoint paths are assumed to have the following format: /epoch__step_ + If a metric_name is provided the format should be /epoch__step__= + This will always be the case if the CheckpointManager class is used to produce their names. Args: dirpath: parent directory where checkpoints are saved. @@ -302,7 +546,9 @@ def get_checkpoint_dirpaths( metric_name: fetches all the checkpoint directories containing the metric name only. process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - Note: When fetching checkpoints in a distributed environment, gloo process groups are recommended over nccl. + Note: + When doing distributed training, only rank 0 will read the file system. The result will be broadcasted to all ranks. + gloo process groups are recommended over nccl. """ return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) @@ -353,7 +599,8 @@ def _retrieve_checkpoint_dirpaths( Args: dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. metric_name: Name of the metric that must exist in checkpoint name. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + metric_name: Name of the metric that must exist in checkpoint name. """ fs, _ = url_to_fs(dirpath)