Skip to content

Commit

Permalink
Add CheckpointManager abstraction in utils/checkpoint.py
Browse files Browse the repository at this point in the history
Differential Revision: D56427226
  • Loading branch information
diego-urgell authored and facebook-github-bot committed May 1, 2024
1 parent 3f89a81 commit 9ceac42
Show file tree
Hide file tree
Showing 3 changed files with 616 additions and 9 deletions.
358 changes: 358 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch

import torch

import torch.distributed as dist
from torch import nn
from torchsnapshot import Snapshot
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
from torchtnt.framework._test_utils import Batch
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit
from torchtnt.utils import get_global_rank, init_from_env

from torchtnt.utils.checkpoint import (
Expand All @@ -25,6 +29,8 @@
_retrieve_checkpoint_dirpaths,
_sort_by_metric_value,
_sort_by_recency,
BestCheckpointConfig,
CheckpointManager,
CheckpointPath,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
Expand Down Expand Up @@ -190,6 +196,349 @@ def test_pickling(self) -> None:
self.assertEqual(unpickled, ckpt)


class CheckpointManagerTest(unittest.TestCase):
def test_create_checkpoint_manager(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
paths = [
f"{temp_dir}/epoch_1_step_3",
f"{temp_dir}/epoch_0_step_1",
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
f"{temp_dir}/epoch_1_step_1",
f"{temp_dir}/epoch_1_step_2_loss=0.5",
f"{temp_dir}/epoch_2_step_5_loss=0.3",
f"{temp_dir}/epoch_0_step_2_acc=0.7",
]
for path in paths:
os.mkdir(path)

# without last_n_checkpoints
ckpt_manager = CheckpointManager(temp_dir)
self.assertEqual(ckpt_manager._ckpt_paths, [])

# with last_n_checkpoints but without metric
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=2)
self.assertEqual(
[x.path for x in ckpt_manager._ckpt_paths],
[
f"{temp_dir}/epoch_0_step_1",
f"{temp_dir}/epoch_0_step_2_acc=0.7",
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
f"{temp_dir}/epoch_1_step_1",
f"{temp_dir}/epoch_1_step_2_loss=0.5",
f"{temp_dir}/epoch_1_step_3",
f"{temp_dir}/epoch_2_step_5_loss=0.3",
],
)

# with last_n_checkpoints and metric min
ckpt_manager = CheckpointManager(
temp_dir,
keep_last_n_checkpoints=3,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="loss", mode="min"
),
)
self.assertEqual(
[x.path for x in ckpt_manager._ckpt_paths],
[
f"{temp_dir}/epoch_1_step_2_loss=0.5",
f"{temp_dir}/epoch_2_step_5_loss=0.3",
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
],
)

# with last_n_checkpoints and metric max
ckpt_manager = CheckpointManager(
temp_dir,
keep_last_n_checkpoints=3,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="loss", mode="max"
),
)
self.assertEqual(
[x.path for x in ckpt_manager._ckpt_paths],
[
f"{temp_dir}/epoch_0_step_5_loss=-0.3",
f"{temp_dir}/epoch_2_step_5_loss=0.3",
f"{temp_dir}/epoch_1_step_2_loss=0.5",
],
)

# with last_n_checkpoints and non previously tracked metric
ckpt_manager = CheckpointManager(
temp_dir,
keep_last_n_checkpoints=3,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="foo", mode="max"
),
)
self.assertEqual(ckpt_manager._ckpt_paths, [])

@skip_if_not_distributed
def test_create_checkpoint_manager_distributed(self) -> None:
spawn_multi_process(
2,
"gloo",
self._test_create_checkpoint_manager_distributed,
)

@staticmethod
def _test_create_checkpoint_manager_distributed() -> None:
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
paths = ["epoch_1_step_2", "epoch_0_step_1", "epoch_1_step_1"]
for path in paths:
os.mkdir(os.path.join(temp_dir, path))
else:
temp_dir = ""

tc = unittest.TestCase()

# without top k config
ckpt_manager = CheckpointManager(temp_dir)
tc.assertNotEqual(ckpt_manager.dirpath, "")
tc.assertEqual(ckpt_manager._ckpt_paths, [])

# with top k config
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
tc.assertNotEqual(ckpt_manager.dirpath, "")
tc.assertEqual(
[str(x) for x in ckpt_manager._ckpt_paths],
[
os.path.join(ckpt_manager.dirpath, path)
for path in [
"epoch_0_step_1",
"epoch_1_step_1",
"epoch_1_step_2",
]
],
)

def test_prune_surplus_checkpoints(self) -> None:
# with checkpoints to delete
with tempfile.TemporaryDirectory() as temp_dir:
ckpt_manager = CheckpointManager(temp_dir, keep_last_n_checkpoints=1)
paths = [
CheckpointPath(temp_dir, 0, 0),
CheckpointPath(temp_dir, 0, 1),
CheckpointPath(temp_dir, 1, 0),
]
for path in paths:
os.mkdir(path.path)

ckpt_manager._ckpt_paths = list(paths)
warning_messages = []
expected_warning_msg = (
f"3 checkpoints found in {temp_dir}. ",
f"Deleting {2} oldest ",
"checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
)
with patch(
f"{CheckpointManager.__module__}.logging.Logger.warning",
warning_messages.append,
):
ckpt_manager.prune_surplus_checkpoints()

self.assertEqual(warning_messages[0], expected_warning_msg)
self.assertEqual(ckpt_manager._ckpt_paths, [paths[2]])
self.assertTrue(os.path.exists(paths[2].path))
self.assertFalse(os.path.exists(paths[0].path))
self.assertFalse(os.path.exists(paths[1].path))

# without checkpoints to delete
with tempfile.TemporaryDirectory() as temp_dir:
ckpt_manager = CheckpointManager(temp_dir)
paths = [
CheckpointPath(temp_dir, 0, 0),
CheckpointPath(temp_dir, 0, 1),
CheckpointPath(temp_dir, 1, 0),
]
ckpt_manager._ckpt_paths = list(paths)
ckpt_manager.prune_surplus_checkpoints()
self.assertEqual(ckpt_manager._ckpt_paths, paths)

def test_generate_checkpoint_path(self) -> None:
ckpt_manager = CheckpointManager("foo")

self.assertEqual(
ckpt_manager.generate_checkpoint_path(1, 1).path,
"foo/epoch_1_step_1",
)

self.assertEqual(
ckpt_manager.generate_checkpoint_path(1, 3).path,
"foo/epoch_1_step_3",
)

ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
monitored_metric="val_loss", mode="min"
)
self.assertEqual(
ckpt_manager.generate_checkpoint_path(
1, 3, MetricData("val_loss", 0.5)
).path,
"foo/epoch_1_step_3_val_loss=0.5",
)

# best checkpoint config, but did not pass metric data - expect path but no metric
self.assertEqual(
ckpt_manager.generate_checkpoint_path(1, 2).path,
"foo/epoch_1_step_2",
)

# passed metric data is tracking a different metric than best checkpoint config - expect exception
with self.assertRaisesRegex(
AssertionError,
"Attempted to get a checkpoint with metric 'mean', but best checkpoint config is for 'val_loss'",
):
ckpt_manager.generate_checkpoint_path(1, 2, MetricData("mean", 3.5))

# no best checkpoint config, but passed metric data - expect exception
ckpt_manager._best_checkpoint_config = None
with self.assertRaisesRegex(
AssertionError,
"Attempted to get a checkpoint with metric but best checkpoint config is not set",
):
ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5))

def test_append_checkpoint_by_recency(self) -> None:
ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2)
ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)]

# without need to remove old by recency
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 1))
self.assertEqual(
ckpt_manager._ckpt_paths,
[CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)],
)

# removing old by recency
with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2))
self.assertEqual(
ckpt_manager._ckpt_paths,
[CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)],
)
mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True)

def test_append_checkpoint_by_metric(self) -> None:
ckpt_manager = CheckpointManager(
"foo",
keep_last_n_checkpoints=5,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="val_loss", mode="max"
),
)
paths = [
CheckpointPath(
"foo", 0, x, metric_data=MetricData(name="val_loss", value=0.01 * x)
)
for x in range(1, 7, 1)
]
ckpt_manager._ckpt_paths = [paths[1], paths[2], paths[4]]
# without need to remove old by min metric, goes beginning
ckpt_manager.append_checkpoint(paths[0])
self.assertEqual(
ckpt_manager._ckpt_paths,
[paths[0], paths[1], paths[2], paths[4]],
)
# without need to remove old by min metric, goes end
ckpt_manager.append_checkpoint(paths[5])
self.assertEqual(
ckpt_manager._ckpt_paths,
[paths[0], paths[1], paths[2], paths[4], paths[5]],
)
# removing old max metric, goes middle
with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm:
ckpt_manager.append_checkpoint(paths[3])
self.assertEqual(
ckpt_manager._ckpt_paths,
[paths[1], paths[2], paths[3], paths[4], paths[5]],
)
mock_rm.assert_called_once_with(
"foo/epoch_0_step_1_val_loss=0.01", recursive=True
)

# no metric data - noop
ckpt_manager._keep_last_n_checkpoints = None
ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 8))
self.assertEqual(
ckpt_manager._ckpt_paths,
[paths[1], paths[2], paths[3], paths[4], paths[5]],
)

def test_should_save_checkpoint(self) -> None:
"""
Tests basic functionality of should_save_checkpoint
"""
ckpt_manager = CheckpointManager("foo")

# test default behavior
ckpt = CheckpointPath("foo", 0, 2)
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))

ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 1)]
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))
ckpt_manager._keep_last_n_checkpoints = 1
self.assertTrue(ckpt_manager.should_save_checkpoint(ckpt))

ckpt_manager._ckpt_paths = [
CheckpointPath(
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
),
]
ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
monitored_metric="val_loss",
mode="min",
)

bigger_metric = CheckpointPath(
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.02)
)
smaller_metric = CheckpointPath(
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.001)
)
ckpt_manager._keep_last_n_checkpoints = None
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))
ckpt_manager._keep_last_n_checkpoints = 1
self.assertFalse(ckpt_manager.should_save_checkpoint(bigger_metric))
self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
ckpt_manager._keep_last_n_checkpoints = 2
self.assertTrue(ckpt_manager.should_save_checkpoint(smaller_metric))
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))

# Make sure we are actually comparing against more optimal element
ckpt_manager._ckpt_paths = [
CheckpointPath(
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.01)
),
CheckpointPath(
"foo", 0, 1, metric_data=MetricData(name="val_loss", value=0.05)
),
]

ckpt_manager._best_checkpoint_config = BestCheckpointConfig(
monitored_metric="val_loss",
mode="max",
)
ckpt_manager._keep_last_n_checkpoints = 2
self.assertTrue(ckpt_manager.should_save_checkpoint(bigger_metric))

def test_remove_worst_checkpoint(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
os.mkdir(os.path.join(temp_dir, "epoch_0_step_0"))
os.mkdir(os.path.join(temp_dir, "epoch_0_step_1"))

ckpt_manager = CheckpointManager(temp_dir)
ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 0))
ckpt_manager.append_checkpoint(CheckpointPath(temp_dir, 0, 1))

ckpt_manager.remove_checkpoint()
self.assertFalse(os.path.exists(os.path.join(temp_dir, "epoch_0_step_0")))
self.assertTrue(os.path.exists(os.path.join(temp_dir, "epoch_0_step_1")))
self.assertEqual(ckpt_manager._ckpt_paths, [CheckpointPath(temp_dir, 0, 1)])


class CheckpointUtilsTest(unittest.TestCase):
@staticmethod
def _create_snapshot_metadata(output_dir: str) -> None:
Expand Down Expand Up @@ -590,3 +939,12 @@ def test_metadata_exists(self) -> None:

os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME))
self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME))


class MyValLossUnit(TrainUnit[Batch]):
def __init__(self) -> None:
super().__init__()
self.val_loss = 0.01

def train_step(self, state: State, data: Batch) -> None:
return None
Loading

0 comments on commit 9ceac42

Please sign in to comment.