diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py
index 6c4e86b53f..49be9560b3 100644
--- a/tests/utils/test_checkpoint.py
+++ b/tests/utils/test_checkpoint.py
@@ -10,7 +10,7 @@
 import shutil
 import tempfile
 import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import torch
 
@@ -787,6 +787,48 @@ def test_remove_worst_checkpoint(self) -> None:
             self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
             self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])
 
+    @patch(
+        "fsspec.implementations.local.LocalFileSystem.rm",
+        side_effect=Exception("OSError: [Errno 2] No such file or directory"),
+    )
+    def test_remove_worst_checkpoint_exception(self, mock_rm: MagicMock) -> None:
+        """
+        A failure to delete the evicted checkpoint is logged, not raised, and the
+        checkpoint is no longer tracked by the manager.
+        """
+        with tempfile.TemporaryDirectory() as temp_dir:
+            os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_0"))
+            os.mkdir(os.path.join(temp_dir, "epoch_0_train_step_1"))
+
+            ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
+
+            # Capture error logs by routing Logger.error into a plain list
+            log_container = []
+            with patch(
+                "torchtnt.utils.checkpoint.logging.Logger.error", log_container.append
+            ):
+                ckpt_manager.append_checkpoint(
+                    CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2})
+                )
+
+            self.assertEqual(
+                log_container,
+                [
+                    (
+                        f"Failed to remove checkpoint '{temp_dir}/epoch_0_train_step_0' for bookkeeping purposes. "
+                        "Do not use it to restore since it may be corrupted! Exception: OSError: [Errno 2] No such file or directory"
+                    )
+                ],
+            )
+            # Make sure we are not tracking the oldest one anymore, even if it was not deleted
+            self.assertEqual(
+                ckpt_manager._ckpt_paths,
+                [
+                    CheckpointPath(temp_dir, 0, {Phase.TRAIN: 1}),
+                    CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2}),
+                ],
+            )
+
 
 class CheckpointUtilsTest(unittest.TestCase):
     @staticmethod
diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py
index 149dcb1302..e13fcf7e3d 100644
--- a/torchtnt/utils/checkpoint.py
+++ b/torchtnt/utils/checkpoint.py
@@ -487,7 +487,7 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
             ckpt: The checkpoint to save.
             state: The training state.
         """
-        # Remove oldest checkpoint if needed
+        # Remove oldest/worst checkpoint if needed
         max_ckpts = self._keep_last_n_checkpoints
         if max_ckpts and len(self._ckpt_paths) >= max_ckpts:
             self.remove_checkpoint()
@@ -542,7 +542,16 @@ def remove_checkpoint(self) -> None:
         worst_ckpt_path = self._ckpt_paths.pop(0)
         if self._pg_wrapper.get_rank() == 0:
             fs, _ = url_to_fs(self.dirpath)
-            fs.rm(worst_ckpt_path.path, recursive=True)
+            try:
+                fs.rm(worst_ckpt_path.path, recursive=True)
+            except Exception as exc:
+                # Best effort: the entry is already untracked, so log and carry on
+                logger.error(
+                    (
+                        f"Failed to remove checkpoint '{worst_ckpt_path}' for bookkeeping purposes. "
+                        f"Do not use it to restore since it may be corrupted! Exception: {exc}"
+                    )
+                )
 
 
 @rank_zero_read_and_broadcast