Skip to content

Commit

Permalink
Don't fail job if CheckpointManager fails deleting older checkpoints (#882)
Browse files Browse the repository at this point in the history

Summary: Pull Request resolved: #882

Reviewed By: anshulverma, JKSenthil

Differential Revision: D61307267

fbshipit-source-id: 4dc97353c34b6dcfbf04b374784107bc636d7f0f
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Aug 16, 2024
1 parent 123453e commit 041ebe1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 3 deletions.
39 changes: 38 additions & 1 deletion tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import torch

Expand Down Expand Up @@ -787,6 +787,43 @@ def test_remove_worst_checkpoint(self) -> None:
self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])

@patch(
    "fsspec.implementations.local.LocalFileSystem.rm",
    side_effect=Exception("OSError: [Errno 2] No such file or directory"),
)
def test_remove_worst_checkpoint_exception(self, mock_url_to_fs: MagicMock) -> None:
    """
    Appending a checkpoint must not raise when deleting the evicted one fails.

    ``LocalFileSystem.rm`` is patched to always raise, and ``Logger.error`` is
    redirected into a plain list. The test then checks two things: that the
    removal failure was reported through the logger with the expected message,
    and that the manager no longer tracks the checkpoint it tried to delete.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Two checkpoint directories already on disk before the manager starts.
        for existing_ckpt in ("epoch_0_train_step_0", "epoch_0_train_step_1"):
            os.mkdir(os.path.join(temp_dir, existing_ckpt))

        ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)

        # Capture every message sent to Logger.error while appending.
        logged_errors = []
        error_capture = patch(
            "torchtnt.utils.checkpoint.logging.Logger.error", logged_errors.append
        )
        with error_capture:
            incoming_ckpt = CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2})
            ckpt_manager.append_checkpoint(incoming_ckpt)

        expected_error = (
            f"Failed to remove checkpoint '{temp_dir}/epoch_0_train_step_0' for bookkeeping purposes. "
            "Do not use it to restore since it may be corrupted! Exception: OSError: [Errno 2] No such file or directory"
        )
        self.assertEqual(logged_errors, [expected_error])

        # The oldest checkpoint must be dropped from tracking even though
        # its directory could not be removed from disk.
        still_tracked = [
            CheckpointPath(temp_dir, 0, {Phase.TRAIN: 1}),
            CheckpointPath(temp_dir, 0, {Phase.TRAIN: 2}),
        ]
        self.assertEqual(ckpt_manager._ckpt_paths, still_tracked)


class CheckpointUtilsTest(unittest.TestCase):
@staticmethod
Expand Down
12 changes: 10 additions & 2 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def append_checkpoint(self, ckpt: CheckpointPath) -> None:
ckpt: The checkpoint to save.
state: The training state.
"""
# Remove oldest checkpoint if needed
# Remove oldest/worst checkpoint if needed
max_ckpts = self._keep_last_n_checkpoints
if max_ckpts and len(self._ckpt_paths) >= max_ckpts:
self.remove_checkpoint()
Expand Down Expand Up @@ -542,7 +542,15 @@ def remove_checkpoint(self) -> None:
worst_ckpt_path = self._ckpt_paths.pop(0)
if self._pg_wrapper.get_rank() == 0:
fs, _ = url_to_fs(self.dirpath)
fs.rm(worst_ckpt_path.path, recursive=True)
try:
fs.rm(worst_ckpt_path.path, recursive=True)
except Exception as exc:
logger.error(
(
f"Failed to remove checkpoint '{worst_ckpt_path}' for bookkeeping purposes. "
f"Do not use it to restore since it may be corrupted! Exception: {exc}"
)
)


@rank_zero_read_and_broadcast
Expand Down

0 comments on commit 041ebe1

Please sign in to comment.