diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index d048d499e8..f8c6aac7da 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -9,37 +9,31 @@
import os
import shutil
import tempfile
-import time
import unittest
from typing import Any, Dict, Iterator, List
from unittest import mock
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, patch
import torch
-import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
from torchtnt.framework._test_utils import (
DummyAutoUnit,
- DummyFitUnit,
DummyTrainUnit,
generate_random_dataloader,
- get_dummy_fit_state,
get_dummy_train_state,
)
-from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.callbacks.torchsnapshot_saver import (
_get_app_state,
_override_knobs,
- KnobOptions,
- RestoreOptions,
TorchSnapshotSaver,
)
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
-from torchtnt.utils.env import init_from_env, seed
+from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import spawn_multi_process
@@ -47,99 +41,6 @@ class TorchSnapshotSaverTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()
- def test_save_every_n_train_steps(self) -> None:
- input_dim = 2
- dataset_len = 10
- batch_size = 2
- max_epochs = 2
- expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
- save_every_n_train_steps = 2
-
- my_unit = DummyTrainUnit(input_dim=input_dim)
- dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
- expected_paths: List[str] = []
- with tempfile.TemporaryDirectory() as temp_dir:
- cumulative_steps = 0
- for epoch in range(max_epochs):
- for _ in range(
- save_every_n_train_steps,
- expected_steps_per_epoch + 1,
- save_every_n_train_steps,
- ):
- cumulative_steps += save_every_n_train_steps
- expected_paths.append(
- os.path.join(temp_dir, f"epoch_{epoch}_step_{cumulative_steps}")
- )
- snapshot = TorchSnapshotSaver(
- temp_dir,
- save_every_n_train_steps=save_every_n_train_steps,
- )
- # Artificially increase the step duration, otherwise torchsnapshot
- # doesn't have the time to save all snapshots and will skip some.
- slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1))
- train(
- my_unit,
- dataloader,
- max_epochs=max_epochs,
- callbacks=[snapshot, slowdown],
- )
- for path in expected_paths:
- self.assertTrue(os.path.exists(path) and os.path.isdir(path))
-
- def test_save_every_n_train_epochs(self) -> None:
- input_dim = 2
- dataset_len = 10
- batch_size = 2
- max_epochs = 3
- expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
- save_every_n_train_epochs = 2
-
- my_unit = DummyTrainUnit(input_dim=input_dim)
- dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
- with tempfile.TemporaryDirectory() as temp_dir:
- expected_path = os.path.join(
- temp_dir,
- f"epoch_{save_every_n_train_epochs}_step_{expected_steps_per_epoch * (save_every_n_train_epochs)}",
- )
- snapshot = TorchSnapshotSaver(
- temp_dir,
- save_every_n_epochs=save_every_n_train_epochs,
- )
- train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot])
- self.assertTrue(
- os.path.exists(expected_path) and os.path.isdir(expected_path)
- )
-
- @patch.object(TorchSnapshotSaver, "_async_snapshot", autospec=True)
- def test_save_fit_entrypoint(self, mock_async_snapshot: Mock) -> None:
- input_dim = 2
-
- my_unit = DummyFitUnit(input_dim=input_dim)
- with tempfile.TemporaryDirectory() as temp_dir:
- snapshot = TorchSnapshotSaver(
- temp_dir, save_every_n_train_steps=1, save_every_n_epochs=1
- )
- train_state = get_dummy_train_state()
- fit_state = get_dummy_fit_state()
- my_unit.train_progress._num_steps_completed = 15
- my_unit.eval_progress._num_steps_completed = 10
-
- snapshot.on_train_step_end(train_state, my_unit)
- snapshot_path = mock_async_snapshot.call_args.args[1]
- self.assertIn(f"epoch_0_step_{15}", snapshot_path)
-
- snapshot.on_train_step_end(fit_state, my_unit)
- snapshot_path = mock_async_snapshot.call_args.args[1]
- self.assertIn(f"epoch_0_step_{15 + 10}", snapshot_path)
-
- snapshot.on_train_epoch_end(train_state, my_unit)
- snapshot_path = mock_async_snapshot.call_args.args[1]
- self.assertIn(f"epoch_0_step_{15}", snapshot_path)
-
- snapshot.on_train_epoch_end(fit_state, my_unit)
- snapshot_path = mock_async_snapshot.call_args.args[1]
- self.assertIn(f"epoch_0_step_{15 + 10}", snapshot_path)
-
def test_save_restore(self) -> None:
input_dim = 2
dataset_len = 10
@@ -247,10 +148,6 @@ def test_restore_from_latest(self) -> None:
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
- # Include a directory that does not have snapshot metadata saved
- # The restore function should skip this
- os.mkdir(os.path.join(temp_dir, "epoch_100_step_200"))
-
with mock.patch(
"torchtnt.framework.callbacks.torchsnapshot_saver.TorchSnapshotSaver.restore"
) as mock_restore:
@@ -261,27 +158,6 @@ def test_restore_from_latest(self) -> None:
)
self.assertTrue(restored)
- def test_restore_from_latest_empty_dir(self) -> None:
- input_dim = 2
- save_every_n_train_steps = 2
-
- my_unit = DummyTrainUnit(input_dim=input_dim)
- with tempfile.TemporaryDirectory() as temp_dir:
- snapshot_cb = TorchSnapshotSaver(
- temp_dir,
- save_every_n_train_steps=save_every_n_train_steps,
- )
-
- with self.assertLogs(level="WARNING") as log:
- restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
- self.assertEqual(
- log.output,
- [
- f"WARNING:torchtnt.framework.callbacks._checkpoint_utils:Input dirpath doesn't contain any subdirectories: {temp_dir}"
- ],
- )
- self.assertFalse(restored)
-
def test_save_restore_no_train_progress(self) -> None:
input_dim = 2
dataset_len = 10
@@ -352,70 +228,6 @@ def test_save_restore_no_lr_scheduler_restore(
app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
self.assertIn("lr_scheduler", app_state)
- def test_save_on_train_end(self) -> None:
- input_dim = 2
- dataset_len = 10
- batch_size = 2
- max_epochs = 2
-
- expected_path = (
- f"epoch_{max_epochs}_step_{max_epochs * (dataset_len // batch_size)}"
- )
-
- my_unit = DummyTrainUnit(input_dim=input_dim)
- dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
- with tempfile.TemporaryDirectory() as temp_dir:
- self.assertFalse(os.path.exists(os.path.join(temp_dir, expected_path)))
- snapshot_cb = TorchSnapshotSaver(
- temp_dir,
- )
- train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
-
- expected_path = (
- f"epoch_{max_epochs}_step_{max_epochs * (dataset_len // batch_size)}"
- )
- self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
-
- with self.assertLogs(level="WARNING") as log:
- # train again without resetting progress
- train(
- my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb]
- )
- self.assertEqual(
- log.output,
- [
- "WARNING:torchtnt.framework.callbacks.torchsnapshot_saver:Final checkpoint already exists, skipping."
- ],
- )
-
- @unittest.skipUnless(
- condition=distributed_available, reason="Torch distributed is needed to run"
- )
- def test_directory_sync_collective(self) -> None:
- spawn_multi_process(
- 2,
- "gloo",
- self._directory_sync_collective,
- )
-
- @staticmethod
- def _directory_sync_collective() -> None:
- init_from_env()
- try:
- if get_global_rank() == 0:
- temp_dir = tempfile.mkdtemp()
- else:
- temp_dir = "foo"
-
- snapshot_cb = TorchSnapshotSaver(temp_dir)
- dirpath = snapshot_cb.dirpath
- tc = unittest.TestCase()
- tc.assertTrue("tmp" in dirpath)
- tc.assertFalse("foo" in dirpath)
- finally:
- if get_global_rank() == 0:
- shutil.rmtree(temp_dir) # delete temp directory
-
@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@@ -470,25 +282,6 @@ def _save_restore_fsdp() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory
- def test_saver_invalid_args(self) -> None:
- with tempfile.TemporaryDirectory() as temp_dir:
- with self.assertRaisesRegex(
- ValueError, "Invalid value passed for save_every_n_train_steps.*"
- ):
- TorchSnapshotSaver(temp_dir, save_every_n_train_steps=-2)
- with self.assertRaisesRegex(
- ValueError, "Invalid value passed for save_every_n_train_steps.*"
- ):
- TorchSnapshotSaver(temp_dir, save_every_n_train_steps=0)
- with self.assertRaisesRegex(
- ValueError, "Invalid value passed for save_every_n_epochs.*"
- ):
- TorchSnapshotSaver(temp_dir, save_every_n_epochs=-2)
- with self.assertRaisesRegex(
- ValueError, "Invalid value passed for save_every_n_epochs.*"
- ):
- TorchSnapshotSaver(temp_dir, save_every_n_epochs=0)
-
@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@@ -549,42 +342,6 @@ def _save_restore_ddp() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory
- @unittest.skipUnless(
- condition=distributed_available, reason="Torch distributed is needed to run"
- )
- @unittest.skipUnless(
- condition=cuda_available, reason="This test needs a GPU host to run."
- )
- def test_process_group_plumbing(self) -> None:
- """
- Creates a new process group and verifies that it's passed through correctly
- """
- spawn_multi_process(
- 2,
- "nccl",
- self._test_process_group_plumbing,
- )
-
- @staticmethod
- def _test_process_group_plumbing() -> None:
- new_pg = dist.new_group(backend="gloo")
-
- if get_global_rank() == 0:
- temp_dir = tempfile.mkdtemp()
- else:
- temp_dir = ""
-
- snapshot_cb = TorchSnapshotSaver(
- temp_dir,
- process_group=new_pg,
- )
- tc = unittest.TestCase()
- try:
- tc.assertEqual(snapshot_cb._process_group, new_pg)
- finally:
- if get_global_rank() == 0:
- shutil.rmtree(temp_dir) # delete temp directory
-
def test_knob_override(self) -> None:
env_var = "TORCHSNAPSHOT_MAX_PER_RANK_IO_CONCURRENCY_OVERRIDE"
knob_options = KnobOptions(max_per_rank_io_concurrency=1)
@@ -595,132 +352,6 @@ def test_knob_override(self) -> None:
with _override_knobs(KnobOptions(max_per_rank_io_concurrency=None)):
self.assertNotIn(env_var, os.environ)
- def test_should_remove_snapshot(self) -> None:
- """
- Tests the helper function that checks if snapshot should be removed or not
- """
- tss = TorchSnapshotSaver("temp")
-
- # keep_last_n_checkpoints is toggled off
- self.assertFalse(tss._should_remove_snapshot())
-
- # not enough checkpoints are saved yet to be removed
- tss._keep_last_n_checkpoints = 2
- tss._ckpt_dirpaths = ["bar"]
- self.assertFalse(tss._should_remove_snapshot())
-
- # enough checkpoints are there to remove
- tss._keep_last_n_checkpoints = 2
- tss._ckpt_dirpaths = ["foo", "bar"]
- self.assertTrue(tss._should_remove_snapshot())
-
- @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_checkpoint")
- def test_remove_snapshot(self, mock_delete_checkpoint: MagicMock) -> None:
- """
- Tests the helper function that removes snapshots and updates the checkpoint paths
- """
- state = get_dummy_train_state()
- tss = TorchSnapshotSaver("temp")
- tss._ckpt_dirpaths = ["foo", "bar"]
- tss._remove_snapshot(state)
-
- mock_delete_checkpoint.assert_called_once()
- self.assertEqual(len(tss._ckpt_dirpaths), 1)
- self.assertEqual(tss._ckpt_dirpaths[0], "bar")
-
- @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_checkpoint")
- def test_cleanup_surplus(self, mock_delete_checkpoint: MagicMock) -> None:
- """
- Tests surplus of checkpoints being cleaned up
- """
- state = get_dummy_train_state()
- unit = DummyTrainUnit(input_dim=2)
- warning_messages = []
- with tempfile.TemporaryDirectory() as temp_dir:
- tss = TorchSnapshotSaver(temp_dir, keep_last_n_checkpoints=1)
- tss._ckpt_dirpaths = ["foo", "bar", "baz"]
-
- expected_warning_msg = " ".join(
- [
- f"3 checkpoints found in {temp_dir}.",
- f"Deleting {2} oldest",
- "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
- ]
- )
-
- with patch(
- "torchtnt.framework.callbacks.torchsnapshot_saver.logging.Logger.warning",
- warning_messages.append,
- ):
- tss.on_train_start(state, unit)
- self.assertEqual(tss._ckpt_dirpaths, ["baz"])
- self.assertEqual(warning_messages[0], expected_warning_msg)
-
- tss = TorchSnapshotSaver(temp_dir)
- tss._ckpt_dirpaths = ["foo", "bar", "baz"]
-
- tss.on_train_start(state, unit)
- self.assertEqual(tss._ckpt_dirpaths, ["foo", "bar", "baz"])
-
- def test_keep_last_n_checkpoints(self) -> None:
- """
- Tests removing checkpoint directories
- """
- unit = DummyTrainUnit(input_dim=2)
- state = get_dummy_train_state()
- with tempfile.TemporaryDirectory() as temp_dir:
- tss = TorchSnapshotSaver(
- temp_dir,
- save_every_n_train_steps=1,
- keep_last_n_checkpoints=2,
- )
-
- # take 10 steps
- for _ in range(10):
- unit.train_progress.increment_step()
- tss.on_train_step_end(state, unit)
- # TODO remove time.sleep to avoid potential flaky test
- time.sleep(0.1) # sleep to ensure enough time to checkpoint
-
- dirs = os.listdir(temp_dir)
- self.assertEqual(len(dirs), 2)
- self.assertIn("epoch_0_step_9", dirs)
- self.assertIn("epoch_0_step_10", dirs)
-
- def test_keep_last_n_checkpoints_e2e(self) -> None:
- """
- Tests removing checkpoint directories e2e
- """
- input_dim = 2
- dataset_len = 10
- batch_size = 2
- max_epochs = 2
-
- my_unit = DummyTrainUnit(input_dim=input_dim)
- dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
- with tempfile.TemporaryDirectory() as temp_dir:
- snapshot_cb = TorchSnapshotSaver(
- temp_dir,
- save_every_n_train_steps=2,
- keep_last_n_checkpoints=1,
- )
- # Artificially increase the step duration, otherwise torchsnapshot
- # doesn't have the time to save all snapshots and will skip some.
- slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1))
-
- train(
- my_unit,
- dataloader,
- max_epochs=max_epochs,
- callbacks=[snapshot_cb, slowdown],
- )
- dirs = os.listdir(temp_dir)
- self.assertEqual(len(dirs), 1)
- self.assertIn(
- f"epoch_{max_epochs}_step_{dataset_len // batch_size * max_epochs}",
- os.listdir(temp_dir),
- )
-
@patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
def test_sync_checkpoint(self, _: MagicMock) -> None:
"""
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 187fe8c1f1..072699d602 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -4,23 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
import logging
-import os
from contextlib import contextmanager, ExitStack
-from typing import Any, cast, Dict, Generator, Iterable, List, Optional, Set, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union
import torch.distributed as dist
from pyre_extensions import none_throws
-from torchtnt.framework.callback import Callback
-from torchtnt.framework.callbacks._checkpoint_utils import (
- _delete_checkpoint,
- _retrieve_checkpoint_dirpaths,
- get_latest_checkpoint_path,
-)
+from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
TEvalUnit,
@@ -29,8 +25,6 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.distributed import get_global_rank, PGWrapper
-from torchtnt.utils.fsspec import get_filesystem
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import Stateful
@@ -38,11 +32,7 @@
try:
import torchsnapshot
from torchsnapshot.knobs import override_max_per_rank_io_concurrency
- from torchsnapshot.snapshot import (
- PendingSnapshot,
- Snapshot,
- SNAPSHOT_METADATA_FNAME,
- )
+ from torchsnapshot.snapshot import PendingSnapshot, Snapshot
_TStateful = torchsnapshot.Stateful
_TORCHSNAPSHOT_AVAILABLE = True
@@ -58,7 +48,7 @@
logger: logging.Logger = logging.getLogger(__name__)
-class TorchSnapshotSaver(Callback):
+class TorchSnapshotSaver(BaseCheckpointer):
"""
A callback which periodically saves the application state during training using `TorchSnapshot `_.
@@ -94,6 +84,8 @@ class TorchSnapshotSaver(Callback):
If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type `_ prior to starting training.
"""
+ metadata_fname: Optional[str] = ".snapshot_metadata"
+
def __init__(
self,
dirpath: str,
@@ -108,151 +100,27 @@ def __init__(
knob_options: Optional[KnobOptions] = None,
) -> None:
_validate_snapshot_available()
- if save_every_n_train_steps is not None and save_every_n_train_steps <= 0:
- raise ValueError(
- f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or positive number, but received {save_every_n_train_steps}"
- )
- if save_every_n_epochs is not None and save_every_n_epochs <= 0:
- raise ValueError(
- f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
- )
- if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
- raise ValueError(
- f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
- )
- self._save_every_n_epochs = save_every_n_epochs
- self._save_every_n_train_steps = save_every_n_train_steps
-
- self._keep_last_n_checkpoints = keep_last_n_checkpoints
- self._ckpt_dirpaths: List[str] = []
- if self._keep_last_n_checkpoints:
- self._ckpt_dirpaths = _retrieve_checkpoint_dirpaths(dirpath)
-
- self._process_group = process_group
- self._pg_wrapper = PGWrapper(process_group)
- self._sync_dirpath_to_all_ranks(dirpath)
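+        # Argument validation, checkpoint bookkeeping, and syncing of dirpath across
+        # ranks are handled by BaseCheckpointer.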
+ super().__init__(
+ dirpath=dirpath,
+ save_every_n_train_steps=save_every_n_train_steps,
+ save_every_n_epochs=save_every_n_epochs,
+ keep_last_n_checkpoints=keep_last_n_checkpoints,
+ process_group=process_group,
+ )
self._async_checkpoint = async_checkpoint
+
self._replicated: Set[str] = set(replicated or [])
self._prev_snapshot: Optional[PendingSnapshot] = None
self._storage_options = storage_options
self._knob_options: KnobOptions = knob_options or KnobOptions()
- def _sync_dirpath_to_all_ranks(self, dirpath: str) -> None:
- if not (dist.is_available() and dist.is_initialized()):
- self._dirpath: str = dirpath
- return
-
- dirpath_container = [dirpath] if get_global_rank() == 0 else [""]
- # broadcast directory from global rank 0
- dist.broadcast_object_list(dirpath_container, src=0, group=self._process_group)
- updated_dirpath = dirpath_container[0]
- if updated_dirpath != dirpath:
- logger.warning(f"Updating dirpath to match rank 0: {updated_dirpath}")
-
- self._dirpath: str = updated_dirpath
-
- @property
- def dirpath(self) -> str:
- """Returns parent directory to save to."""
- return self._dirpath
-
def on_train_start(self, state: State, unit: TTrainUnit) -> None:
"""Validate there's no key collision for the app state."""
app_state = _app_state(unit)
_check_app_state_collision(app_state)
- # clean up the difference if surplus of checkpoints exist
- keep_last_n_checkpoints = self._keep_last_n_checkpoints
- if (
- keep_last_n_checkpoints
- and len(self._ckpt_dirpaths) > keep_last_n_checkpoints
- ):
- logger.warning(
- " ".join(
- [
- f"{len(self._ckpt_dirpaths)} checkpoints found in {self._dirpath}.",
- f"Deleting {len(self._ckpt_dirpaths) - keep_last_n_checkpoints} oldest",
- "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
- ]
- )
- )
- for _ in range(len(self._ckpt_dirpaths) - keep_last_n_checkpoints):
- self._remove_snapshot(state)
-
- def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
- num_steps_completed = unit.train_progress.num_steps_completed
- save_every_n_train_steps = self._save_every_n_train_steps
- if (
- save_every_n_train_steps is None
- or num_steps_completed % save_every_n_train_steps != 0
- ):
- return
-
- epoch = unit.train_progress.num_epochs_completed
- if state.entry_point == EntryPoint.FIT:
- num_steps_completed += cast(
- TEvalUnit, unit
- ).eval_progress.num_steps_completed
- snapshot_path = _get_snapshot_save_path(
- self._dirpath, epoch, num_steps_completed
- )
- self._checkpoint_impl(
- state,
- unit,
- snapshot_path=snapshot_path,
- intra_epoch=True,
- prev_snapshot_wait=False,
- curr_snapshot_wait=False,
- )
-
- def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
- epoch = unit.train_progress.num_epochs_completed
- save_every_n_epochs = self._save_every_n_epochs
- if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
- return
-
- num_steps_completed = unit.train_progress.num_steps_completed
- if state.entry_point == EntryPoint.FIT:
- num_steps_completed += cast(
- TEvalUnit, unit
- ).eval_progress.num_steps_completed
- snapshot_path = _get_snapshot_save_path(
- self._dirpath, epoch, num_steps_completed
- )
- self._checkpoint_impl(
- state,
- unit,
- snapshot_path=snapshot_path,
- intra_epoch=False,
- prev_snapshot_wait=True,
- curr_snapshot_wait=False,
- )
-
- def on_train_end(self, state: State, unit: TTrainUnit) -> None:
- num_steps_completed = unit.train_progress.num_steps_completed
- if state.entry_point == EntryPoint.FIT:
- num_steps_completed += cast(
- TEvalUnit, unit
- ).eval_progress.num_steps_completed
- epoch = unit.train_progress.num_epochs_completed
- snapshot_path = _get_snapshot_save_path(
- self._dirpath, epoch, num_steps_completed
- )
-
- fs = get_filesystem(snapshot_path)
- if fs.exists(os.path.join(snapshot_path, SNAPSHOT_METADATA_FNAME)):
- rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
- return
-
- self._checkpoint_impl(
- state,
- unit,
- snapshot_path=snapshot_path,
- intra_epoch=False,
- prev_snapshot_wait=True,
- curr_snapshot_wait=True,
- )
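+        # BaseCheckpointer.on_train_start handles cleanup of surplus checkpoints
+        # beyond ``keep_last_n_checkpoints``.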
+ super().on_train_start(state, unit)
def on_exception(
self,
@@ -267,22 +135,25 @@ def _checkpoint_impl(
state: State,
unit: AppStateMixin,
*,
- snapshot_path: str,
- intra_epoch: bool,
- prev_snapshot_wait: bool,
- curr_snapshot_wait: bool,
- ) -> None:
+ checkpoint_path: str,
+ hook: str,
+ ) -> bool:
"""
Checkpoint the current state of the application.
-
- Args:
- state: State of the application
- unit: The training/evaluation/prediction unit
- snapshot_path: Path to save the snapshot
- intra_epoch: Whether in middle of epoch or not
- prev_snapshot_wait: Whether to wait for previous snapshot to finish writing
- curr_snapshot_wait: Whether to wait for current snapshot to finish writing
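+
+        Args:
+            checkpoint_path: Path to save the snapshot.
+            hook: Name of the callback hook that triggered this checkpoint. Must be one of
+                "on_train_step_end", "on_train_epoch_end", or "on_train_end".
+
+        Returns:
+            True if the checkpoint was saved successfully.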
"""
+ intra_epoch = False
+ prev_snapshot_wait = False
+ curr_snapshot_wait = False
+ if hook == "on_train_step_end":
+ intra_epoch = True
+ elif hook == "on_train_epoch_end":
+ prev_snapshot_wait = True
+ elif hook == "on_train_end":
+ prev_snapshot_wait = True
+ curr_snapshot_wait = True
+ else:
+ raise RuntimeError(f"Unexpected hook encountered '{hook}'")
+
app_state = _get_app_state(
state,
unit,
@@ -297,39 +168,14 @@ def _checkpoint_impl(
# future, add logic to set successful flag
# only when checkpoint is fully written
checkpoint_success = self._async_snapshot(
- snapshot_path, app_state, wait=prev_snapshot_wait
+ checkpoint_path, app_state, wait=prev_snapshot_wait
)
if curr_snapshot_wait:
self._wait()
else:
- with get_timing_context(
- state, f"{self.__class__.__name__}.take_sync_snapshot"
- ):
- checkpoint_success = self._sync_snapshot(
- snapshot_path, app_state, wait=prev_snapshot_wait
- )
-
- # remove and book keep snapshots related to keep_last_n_checkpoints
- if checkpoint_success:
- if self._should_remove_snapshot():
- self._remove_snapshot(state)
- self._ckpt_dirpaths.append(snapshot_path)
-
- def _should_remove_snapshot(self) -> bool:
- keep_last_n_checkpoints = self._keep_last_n_checkpoints
- return (
- keep_last_n_checkpoints is not None
- and len(self._ckpt_dirpaths) >= keep_last_n_checkpoints
- )
-
- def _remove_snapshot(self, state: State) -> None:
- # remove oldest snapshot directory
- oldest_ckpt_path = self._ckpt_dirpaths.pop(0)
- with get_timing_context(state, f"{self.__class__.__name__}.delete_snapshot"):
- if self._pg_wrapper.get_rank() == 0:
- # only delete on rank 0
- _delete_checkpoint(oldest_ckpt_path, SNAPSHOT_METADATA_FNAME)
- self._pg_wrapper.barrier()
+ with get_timing_context(state, f"{self.__class__.__name__}.take_snapshot"):
+ checkpoint_success = self._sync_snapshot(checkpoint_path, app_state)
+ return checkpoint_success
def _wait(self) -> None:
if self._prev_snapshot is not None:
@@ -366,7 +212,9 @@ def _async_snapshot(
return True
def _sync_snapshot(
- self, snapshot_path: str, app_state: Dict[str, _TStateful], *, wait: bool
+ self,
+ snapshot_path: str,
+ app_state: Dict[str, _TStateful],
) -> bool:
with _override_knobs(self._knob_options):
rank_zero_info(
@@ -467,52 +315,6 @@ def restore(
snapshot.restore(app_state)
rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
- @staticmethod
- def restore_from_latest(
- dirpath: str,
- unit: AppStateMixin,
- *,
- train_dataloader: Optional[Iterable[TTrainData]] = None,
- process_group: Optional[dist.ProcessGroup] = None,
- restore_options: Optional[RestoreOptions] = None,
- storage_options: Optional[Dict[str, Any]] = None,
- knob_options: Optional[KnobOptions] = None,
- ) -> bool:
- """
- Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.
-
- There are additional flags offered should the user want to skip loading the train and eval progress.
- By default, the train and eval progress are restored, if applicable.
-
- Args:
- dirpath: Parent directory from which to get the latest snapshot.
- unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
- train_dataloader: An optional train dataloader to restore.
- process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
- restore_options: Controls what to filter when restoring the state.
- storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
- knob_options: Additional keyword options for the snapshot knobs
-
- Returns:
- True if the latest snapshot directory was found and successfully restored, otherwise False.
- """
- path = get_latest_checkpoint_path(
- dirpath, SNAPSHOT_METADATA_FNAME, process_group=process_group
- )
- if path is None:
- return False
- logger.info(f"Restoring from path: {path}")
- TorchSnapshotSaver.restore(
- path,
- unit,
- train_dataloader=train_dataloader,
- process_group=process_group,
- restore_options=restore_options,
- storage_options=storage_options,
- knob_options=knob_options,
- )
- return True
-
def _validate_snapshot_available() -> None:
if not _TORCHSNAPSHOT_AVAILABLE:
@@ -523,11 +325,6 @@ def _validate_snapshot_available() -> None:
)
-def _get_snapshot_save_path(dirpath: str, epoch: int, step: int) -> str:
- # TODO: discuss whether this path should be customized
- return os.path.join(dirpath, f"epoch_{epoch}_step_{step}")
-
-
def _app_state(unit: AppStateMixin) -> Dict[str, Any]:
"""Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
app_state = unit.app_state()