diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index d048d499e8..f8c6aac7da 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -9,37 +9,31 @@
 import os
 import shutil
 import tempfile
-import time
 import unittest
 from typing import Any, Dict, Iterator, List
 from unittest import mock
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import MagicMock, patch
 
 import torch
-import torch.distributed as dist
 from torch import nn
 from torch.utils.data import DataLoader
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
 from torchtnt.framework._test_utils import (
     DummyAutoUnit,
-    DummyFitUnit,
     DummyTrainUnit,
     generate_random_dataloader,
-    get_dummy_fit_state,
     get_dummy_train_state,
 )
-from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.torchsnapshot_saver import (
     _get_app_state,
     _override_knobs,
-    KnobOptions,
-    RestoreOptions,
     TorchSnapshotSaver,
 )
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank
-from torchtnt.utils.env import init_from_env, seed
+from torchtnt.utils.env import seed
 from torchtnt.utils.test_utils import spawn_multi_process
 
 
@@ -47,99 +41,6 @@ class TorchSnapshotSaverTest(unittest.TestCase):
     cuda_available: bool = torch.cuda.is_available()
     distributed_available: bool = torch.distributed.is_available()
 
-    def test_save_every_n_train_steps(self) -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-        expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
-        save_every_n_train_steps = 2
-
-        my_unit = DummyTrainUnit(input_dim=input_dim)
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        expected_paths: List[str] = []
-        with tempfile.TemporaryDirectory() as temp_dir:
-            cumulative_steps = 0
-            for epoch in range(max_epochs):
-                for _ in range(
-                    save_every_n_train_steps,
-                    expected_steps_per_epoch + 1,
-                    save_every_n_train_steps,
-                ):
-                    cumulative_steps += save_every_n_train_steps
-                    expected_paths.append(
-                        os.path.join(temp_dir, f"epoch_{epoch}_step_{cumulative_steps}")
-                    )
-            snapshot = TorchSnapshotSaver(
-                temp_dir,
-                save_every_n_train_steps=save_every_n_train_steps,
-            )
-            # Artificially increase the step duration, otherwise torchsnapshot
-            # doesn't have the time to save all snapshots and will skip some.
-            slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1))
-            train(
-                my_unit,
-                dataloader,
-                max_epochs=max_epochs,
-                callbacks=[snapshot, slowdown],
-            )
-            for path in expected_paths:
-                self.assertTrue(os.path.exists(path) and os.path.isdir(path))
-
-    def test_save_every_n_train_epochs(self) -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 3
-        expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
-        save_every_n_train_epochs = 2
-
-        my_unit = DummyTrainUnit(input_dim=input_dim)
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        with tempfile.TemporaryDirectory() as temp_dir:
-            expected_path = os.path.join(
-                temp_dir,
-                f"epoch_{save_every_n_train_epochs}_step_{expected_steps_per_epoch * (save_every_n_train_epochs)}",
-            )
-            snapshot = TorchSnapshotSaver(
-                temp_dir,
-                save_every_n_epochs=save_every_n_train_epochs,
-            )
-            train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot])
-            self.assertTrue(
-                os.path.exists(expected_path) and os.path.isdir(expected_path)
-            )
-
-    @patch.object(TorchSnapshotSaver, "_async_snapshot", autospec=True)
-    def test_save_fit_entrypoint(self, mock_async_snapshot: Mock) -> None:
-        input_dim = 2
-
-        my_unit = DummyFitUnit(input_dim=input_dim)
-        with tempfile.TemporaryDirectory() as temp_dir:
-            snapshot = TorchSnapshotSaver(
-                temp_dir, save_every_n_train_steps=1, save_every_n_epochs=1
-            )
-            train_state = get_dummy_train_state()
-            fit_state = get_dummy_fit_state()
-            my_unit.train_progress._num_steps_completed = 15
-            my_unit.eval_progress._num_steps_completed = 10
-
-            snapshot.on_train_step_end(train_state, my_unit)
-            snapshot_path = mock_async_snapshot.call_args.args[1]
-            self.assertIn(f"epoch_0_step_{15}", snapshot_path)
-
-            snapshot.on_train_step_end(fit_state, my_unit)
-            snapshot_path = mock_async_snapshot.call_args.args[1]
-            self.assertIn(f"epoch_0_step_{15 + 10}", snapshot_path)
-
-            snapshot.on_train_epoch_end(train_state, my_unit)
-            snapshot_path = mock_async_snapshot.call_args.args[1]
-            self.assertIn(f"epoch_0_step_{15}", snapshot_path)
-
-            snapshot.on_train_epoch_end(fit_state, my_unit)
-            snapshot_path = mock_async_snapshot.call_args.args[1]
-            self.assertIn(f"epoch_0_step_{15 + 10}", snapshot_path)
-
     def test_save_restore(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
         max_epochs = 2
@@ -247,10 +148,6 @@ def test_restore_from_latest(self) -> None:
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
 
-            # Include a directory that does not have snapshot metadata saved
-            # The restore function should skip this
-            os.mkdir(os.path.join(temp_dir, "epoch_100_step_200"))
-
             with mock.patch(
                 "torchtnt.framework.callbacks.torchsnapshot_saver.TorchSnapshotSaver.restore"
             ) as mock_restore:
@@ -261,27 +158,6 @@ def test_restore_from_latest(self) -> None:
             )
             self.assertTrue(restored)
 
-    def test_restore_from_latest_empty_dir(self) -> None:
-        input_dim = 2
-        save_every_n_train_steps = 2
-
-        my_unit = DummyTrainUnit(input_dim=input_dim)
-        with tempfile.TemporaryDirectory() as temp_dir:
-            snapshot_cb = TorchSnapshotSaver(
-                temp_dir,
-                save_every_n_train_steps=save_every_n_train_steps,
-            )
-
-            with self.assertLogs(level="WARNING") as log:
-                restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
-                self.assertEqual(
-                    log.output,
-                    [
-                        f"WARNING:torchtnt.framework.callbacks._checkpoint_utils:Input dirpath doesn't contain any subdirectories: {temp_dir}"
-                    ],
-                )
-            self.assertFalse(restored)
-
     def test_save_restore_no_train_progress(self) -> None:
         input_dim = 2
         dataset_len = 10
         batch_size = 2
         max_epochs = 2
@@ -352,70 +228,6 @@ def test_save_restore_no_lr_scheduler_restore(
         app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
         self.assertIn("lr_scheduler", app_state)
 
-    def test_save_on_train_end(self) -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-
-        expected_path = (
-            f"epoch_{max_epochs}_step_{max_epochs * (dataset_len // batch_size)}"
-        )
-
-        my_unit = DummyTrainUnit(input_dim=input_dim)
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertFalse(os.path.exists(os.path.join(temp_dir, expected_path)))
-            snapshot_cb = TorchSnapshotSaver(
-                temp_dir,
-            )
-            train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
-
-            expected_path = (
-                f"epoch_{max_epochs}_step_{max_epochs * (dataset_len // batch_size)}"
-            )
-            self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
-
-            with self.assertLogs(level="WARNING") as log:
-                # train again without resetting progress
-                train(
-                    my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb]
-                )
-            self.assertEqual(
-                log.output,
-                [
-                    "WARNING:torchtnt.framework.callbacks.torchsnapshot_saver:Final checkpoint already exists, skipping."
-                ],
-            )
-
-    @unittest.skipUnless(
-        condition=distributed_available, reason="Torch distributed is needed to run"
-    )
-    def test_directory_sync_collective(self) -> None:
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._directory_sync_collective,
-        )
-
-    @staticmethod
-    def _directory_sync_collective() -> None:
-        init_from_env()
-        try:
-            if get_global_rank() == 0:
-                temp_dir = tempfile.mkdtemp()
-            else:
-                temp_dir = "foo"
-
-            snapshot_cb = TorchSnapshotSaver(temp_dir)
-            dirpath = snapshot_cb.dirpath
-            tc = unittest.TestCase()
-            tc.assertTrue("tmp" in dirpath)
-            tc.assertFalse("foo" in dirpath)
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
     @unittest.skipUnless(
         condition=distributed_available, reason="Torch distributed is needed to run"
     )
@@ -470,25 +282,6 @@ def _save_restore_fsdp() -> None:
         if get_global_rank() == 0:
             shutil.rmtree(temp_dir)  # delete temp directory
 
-    def test_saver_invalid_args(self) -> None:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            with self.assertRaisesRegex(
-                ValueError, "Invalid value passed for save_every_n_train_steps.*"
-            ):
-                TorchSnapshotSaver(temp_dir, save_every_n_train_steps=-2)
-            with self.assertRaisesRegex(
-                ValueError, "Invalid value passed for save_every_n_train_steps.*"
-            ):
-                TorchSnapshotSaver(temp_dir, save_every_n_train_steps=0)
-            with self.assertRaisesRegex(
-                ValueError, "Invalid value passed for save_every_n_epochs.*"
-            ):
-                TorchSnapshotSaver(temp_dir, save_every_n_epochs=-2)
-            with self.assertRaisesRegex(
-                ValueError, "Invalid value passed for save_every_n_epochs.*"
-            ):
-                TorchSnapshotSaver(temp_dir, save_every_n_epochs=0)
-
     @unittest.skipUnless(
         condition=distributed_available, reason="Torch distributed is needed to run"
    )
@@ -549,42 +342,6 @@ def _save_restore_ddp() -> None:
         if get_global_rank() == 0:
             shutil.rmtree(temp_dir)  # delete temp directory
 
-    @unittest.skipUnless(
-        condition=distributed_available, reason="Torch distributed is needed to run"
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_process_group_plumbing(self) -> None:
-        """
-        Creates a new process group and verifies that it's passed through correctly
-        """
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._test_process_group_plumbing,
-        )
-
-    @staticmethod
-    def _test_process_group_plumbing() -> None:
-        new_pg = dist.new_group(backend="gloo")
-
-        if get_global_rank() == 0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-
-        snapshot_cb = TorchSnapshotSaver(
-            temp_dir,
-            process_group=new_pg,
-        )
-        tc = unittest.TestCase()
-        try:
-            tc.assertEqual(snapshot_cb._process_group, new_pg)
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
     def test_knob_override(self) -> None:
         env_var = "TORCHSNAPSHOT_MAX_PER_RANK_IO_CONCURRENCY_OVERRIDE"
         knob_options = KnobOptions(max_per_rank_io_concurrency=1)
@@ -595,132 +352,6 @@ def test_knob_override(self) -> None:
         with _override_knobs(KnobOptions(max_per_rank_io_concurrency=None)):
             self.assertNotIn(env_var, os.environ)
 
-    def test_should_remove_snapshot(self) -> None:
-        """
-        Tests the helper function that checks if snapshot should be removed or not
-        """
-        tss = TorchSnapshotSaver("temp")
-
-        # keep_last_n_checkpoints is toggled off
-        self.assertFalse(tss._should_remove_snapshot())
-
-        # not enough checkpoints are saved yet to be removed
-        tss._keep_last_n_checkpoints = 2
-        tss._ckpt_dirpaths = ["bar"]
-        self.assertFalse(tss._should_remove_snapshot())
-
-        # enough checkpoints are there to remove
-        tss._keep_last_n_checkpoints = 2
-        tss._ckpt_dirpaths = ["foo", "bar"]
-        self.assertTrue(tss._should_remove_snapshot())
-
-    @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_checkpoint")
-    def test_remove_snapshot(self, mock_delete_checkpoint: MagicMock) -> None:
-        """
-        Tests the helper function that removes snapshots and updates the checkpoint paths
-        """
-        state = get_dummy_train_state()
-        tss = TorchSnapshotSaver("temp")
-        tss._ckpt_dirpaths = ["foo", "bar"]
-        tss._remove_snapshot(state)
-
-        mock_delete_checkpoint.assert_called_once()
-        self.assertEqual(len(tss._ckpt_dirpaths), 1)
-        self.assertEqual(tss._ckpt_dirpaths[0], "bar")
-
-    @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_checkpoint")
-    def test_cleanup_surplus(self, mock_delete_checkpoint: MagicMock) -> None:
-        """
-        Tests surplus of checkpoints being cleaned up
-        """
-        state = get_dummy_train_state()
-        unit = DummyTrainUnit(input_dim=2)
-        warning_messages = []
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tss = TorchSnapshotSaver(temp_dir, keep_last_n_checkpoints=1)
-            tss._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            expected_warning_msg = " ".join(
-                [
-                    f"3 checkpoints found in {temp_dir}.",
-                    f"Deleting {2} oldest",
-                    "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
-                ]
-            )
-
-            with patch(
-                "torchtnt.framework.callbacks.torchsnapshot_saver.logging.Logger.warning",
-                warning_messages.append,
-            ):
-                tss.on_train_start(state, unit)
-            self.assertEqual(tss._ckpt_dirpaths, ["baz"])
-            self.assertEqual(warning_messages[0], expected_warning_msg)
-
-            tss = TorchSnapshotSaver(temp_dir)
-            tss._ckpt_dirpaths = ["foo", "bar", "baz"]
-
-            tss.on_train_start(state, unit)
-            self.assertEqual(tss._ckpt_dirpaths, ["foo", "bar", "baz"])
-
-    def test_keep_last_n_checkpoints(self) -> None:
-        """
-        Tests removing checkpoint directories
-        """
-        unit = DummyTrainUnit(input_dim=2)
-        state = get_dummy_train_state()
-        with tempfile.TemporaryDirectory() as temp_dir:
-            tss = TorchSnapshotSaver(
-                temp_dir,
-                save_every_n_train_steps=1,
-                keep_last_n_checkpoints=2,
-            )
-
-            # take 10 steps
-            for _ in range(10):
-                unit.train_progress.increment_step()
-                tss.on_train_step_end(state, unit)
-                # TODO remove time.sleep to avoid potential flaky test
-                time.sleep(0.1)  # sleep to ensure enough time to checkpoint
-
-            dirs = os.listdir(temp_dir)
-            self.assertEqual(len(dirs), 2)
-            self.assertIn("epoch_0_step_9", dirs)
-            self.assertIn("epoch_0_step_10", dirs)
-
-    def test_keep_last_n_checkpoints_e2e(self) -> None:
-        """
-        Tests removing checkpoint directories e2e
-        """
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-
-        my_unit = DummyTrainUnit(input_dim=input_dim)
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        with tempfile.TemporaryDirectory() as temp_dir:
-            snapshot_cb = TorchSnapshotSaver(
-                temp_dir,
-                save_every_n_train_steps=2,
-                keep_last_n_checkpoints=1,
-            )
-            # Artificially increase the step duration, otherwise torchsnapshot
-            # doesn't have the time to save all snapshots and will skip some.
-            slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1))
-
-            train(
-                my_unit,
-                dataloader,
-                max_epochs=max_epochs,
-                callbacks=[snapshot_cb, slowdown],
-            )
-            dirs = os.listdir(temp_dir)
-            self.assertEqual(len(dirs), 1)
-            self.assertIn(
-                f"epoch_{max_epochs}_step_{dataset_len // batch_size * max_epochs}",
-                os.listdir(temp_dir),
-            )
-
     @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
     def test_sync_checkpoint(self, _: MagicMock) -> None:
         """
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 187fe8c1f1..072699d602 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -4,23 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from __future__ import annotations
+
 import logging
-import os
 from contextlib import contextmanager, ExitStack
-from typing import Any, cast, Dict, Generator, Iterable, List, Optional, Set, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Union
 
 import torch.distributed as dist
 from pyre_extensions import none_throws
-from torchtnt.framework.callback import Callback
-from torchtnt.framework.callbacks._checkpoint_utils import (
-    _delete_checkpoint,
-    _retrieve_checkpoint_dirpaths,
-    get_latest_checkpoint_path,
-)
+from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
-from torchtnt.framework.state import EntryPoint, State
+from torchtnt.framework.state import State
 from torchtnt.framework.unit import (
     AppStateMixin,
     TEvalUnit,
@@ -29,8 +25,6 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.distributed import get_global_rank, PGWrapper
-from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import Stateful
@@ -38,11 +32,7 @@
 try:
     import torchsnapshot
     from torchsnapshot.knobs import override_max_per_rank_io_concurrency
-    from torchsnapshot.snapshot import (
-        PendingSnapshot,
-        Snapshot,
-        SNAPSHOT_METADATA_FNAME,
-    )
+    from torchsnapshot.snapshot import PendingSnapshot, Snapshot
 
     _TStateful = torchsnapshot.Stateful
     _TORCHSNAPSHOT_AVAILABLE = True
@@ -58,7 +48,7 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-class TorchSnapshotSaver(Callback):
+class TorchSnapshotSaver(BaseCheckpointer):
     """
     A callback which periodically saves the application state during training using
     `TorchSnapshot `_.
@@ -94,6 +84,8 @@ class TorchSnapshotSaver(Callback):
         If checkpointing FSDP model, you can set state_dict type calling `set_state_dict_type `_ prior to starting training.
     """
 
+    metadata_fname: Optional[str] = ".snapshot_metadata"
+
     def __init__(
         self,
         dirpath: str,
@@ -108,151 +100,27 @@ def __init__(
         knob_options: Optional[KnobOptions] = None,
     ) -> None:
         _validate_snapshot_available()
-        if save_every_n_train_steps is not None and save_every_n_train_steps <= 0:
-            raise ValueError(
-                f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or positive number, but received {save_every_n_train_steps}"
-            )
-        if save_every_n_epochs is not None and save_every_n_epochs <= 0:
-            raise ValueError(
-                f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
-            )
-        if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
-            raise ValueError(
-                f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
-            )
-        self._save_every_n_epochs = save_every_n_epochs
-        self._save_every_n_train_steps = save_every_n_train_steps
-
-        self._keep_last_n_checkpoints = keep_last_n_checkpoints
-        self._ckpt_dirpaths: List[str] = []
-        if self._keep_last_n_checkpoints:
-            self._ckpt_dirpaths = _retrieve_checkpoint_dirpaths(dirpath)
-
-        self._process_group = process_group
-        self._pg_wrapper = PGWrapper(process_group)
-        self._sync_dirpath_to_all_ranks(dirpath)
+        super().__init__(
+            dirpath=dirpath,
+            save_every_n_train_steps=save_every_n_train_steps,
+            save_every_n_epochs=save_every_n_epochs,
+            keep_last_n_checkpoints=keep_last_n_checkpoints,
+            process_group=process_group,
+        )
         self._async_checkpoint = async_checkpoint
+        self._replicated: Set[str] = set(replicated or [])
 
         self._prev_snapshot: Optional[PendingSnapshot] = None
         self._storage_options = storage_options
         self._knob_options: KnobOptions = knob_options or KnobOptions()
 
-    def _sync_dirpath_to_all_ranks(self, dirpath: str) -> None:
-        if not (dist.is_available() and dist.is_initialized()):
-            self._dirpath: str = dirpath
-            return
-
-        dirpath_container = [dirpath] if get_global_rank() == 0 else [""]
-        # broadcast directory from global rank 0
-        dist.broadcast_object_list(dirpath_container, src=0, group=self._process_group)
-        updated_dirpath = dirpath_container[0]
-        if updated_dirpath != dirpath:
-            logger.warning(f"Updating dirpath to match rank 0: {updated_dirpath}")
-
-        self._dirpath: str = updated_dirpath
-
-    @property
-    def dirpath(self) -> str:
-        """Returns parent directory to save to."""
-        return self._dirpath
-
     def on_train_start(self, state: State, unit: TTrainUnit) -> None:
         """Validate there's no key collision for the app state."""
         app_state = _app_state(unit)
         _check_app_state_collision(app_state)
 
-        # clean up the difference if surplus of checkpoints exist
-        keep_last_n_checkpoints = self._keep_last_n_checkpoints
-        if (
-            keep_last_n_checkpoints
-            and len(self._ckpt_dirpaths) > keep_last_n_checkpoints
-        ):
-            logger.warning(
-                " ".join(
-                    [
-                        f"{len(self._ckpt_dirpaths)} checkpoints found in {self._dirpath}.",
-                        f"Deleting {len(self._ckpt_dirpaths) - keep_last_n_checkpoints} oldest",
-                        "checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
-                    ]
-                )
-            )
-            for _ in range(len(self._ckpt_dirpaths) - keep_last_n_checkpoints):
-                self._remove_snapshot(state)
-
-    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
-        num_steps_completed = unit.train_progress.num_steps_completed
-        save_every_n_train_steps = self._save_every_n_train_steps
-        if (
-            save_every_n_train_steps is None
-            or num_steps_completed % save_every_n_train_steps != 0
-        ):
-            return
-
-        epoch = unit.train_progress.num_epochs_completed
-        if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
-        snapshot_path = _get_snapshot_save_path(
-            self._dirpath, epoch, num_steps_completed
-        )
-        self._checkpoint_impl(
-            state,
-            unit,
-            snapshot_path=snapshot_path,
-            intra_epoch=True,
-            prev_snapshot_wait=False,
-            curr_snapshot_wait=False,
-        )
-
-    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
-        epoch = unit.train_progress.num_epochs_completed
-        save_every_n_epochs = self._save_every_n_epochs
-        if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
-            return
-
-        num_steps_completed = unit.train_progress.num_steps_completed
-        if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
-        snapshot_path = _get_snapshot_save_path(
-            self._dirpath, epoch, num_steps_completed
-        )
-        self._checkpoint_impl(
-            state,
-            unit,
-            snapshot_path=snapshot_path,
-            intra_epoch=False,
-            prev_snapshot_wait=True,
-            curr_snapshot_wait=False,
-        )
-
-    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
-        num_steps_completed = unit.train_progress.num_steps_completed
-        if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
-        epoch = unit.train_progress.num_epochs_completed
-        snapshot_path = _get_snapshot_save_path(
-            self._dirpath, epoch, num_steps_completed
-        )
-
-        fs = get_filesystem(snapshot_path)
-        if fs.exists(os.path.join(snapshot_path, SNAPSHOT_METADATA_FNAME)):
-            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
-            return
-
-        self._checkpoint_impl(
-            state,
-            unit,
-            snapshot_path=snapshot_path,
-            intra_epoch=False,
-            prev_snapshot_wait=True,
-            curr_snapshot_wait=True,
-        )
+        super().on_train_start(state, unit)
 
     def on_exception(
         self,
@@ -267,22 +135,25 @@ def _checkpoint_impl(
         state: State,
         unit: AppStateMixin,
         *,
-        snapshot_path: str,
-        intra_epoch: bool,
-        prev_snapshot_wait: bool,
-        curr_snapshot_wait: bool,
-    ) -> None:
+        checkpoint_path: str,
+        hook: str,
+    ) -> bool:
         """
         Checkpoint the current state of the application.
-
-        Args:
-            state: State of the application
-            unit: The training/evaluation/prediction unit
-            snapshot_path: Path to save the snapshot
-            intra_epoch: Whether in middle of epoch or not
-            prev_snapshot_wait: Whether to wait for previous snapshot to finish writing
-            curr_snapshot_wait: Whether to wait for current snapshot to finish writing
         """
+        intra_epoch = False
+        prev_snapshot_wait = False
+        curr_snapshot_wait = False
+        if hook == "on_train_step_end":
+            intra_epoch = True
+        elif hook == "on_train_epoch_end":
+            prev_snapshot_wait = True
+        elif hook == "on_train_end":
+            prev_snapshot_wait = True
+            curr_snapshot_wait = True
+        else:
+            raise RuntimeError(f"Unexpected hook encountered '{hook}'")
+
         app_state = _get_app_state(
             state,
             unit,
@@ -297,39 +168,14 @@ def _checkpoint_impl(
                 # future, add logic to set successful flag
                 # only when checkpoint is fully written
                 checkpoint_success = self._async_snapshot(
-                    snapshot_path, app_state, wait=prev_snapshot_wait
+                    checkpoint_path, app_state, wait=prev_snapshot_wait
                 )
                 if curr_snapshot_wait:
                     self._wait()
         else:
-            with get_timing_context(
-                state, f"{self.__class__.__name__}.take_sync_snapshot"
-            ):
-                checkpoint_success = self._sync_snapshot(
-                    snapshot_path, app_state, wait=prev_snapshot_wait
-                )
-
-        # remove and book keep snapshots related to keep_last_n_checkpoints
-        if checkpoint_success:
-            if self._should_remove_snapshot():
-                self._remove_snapshot(state)
-            self._ckpt_dirpaths.append(snapshot_path)
-
-    def _should_remove_snapshot(self) -> bool:
-        keep_last_n_checkpoints = self._keep_last_n_checkpoints
-        return (
-            keep_last_n_checkpoints is not None
-            and len(self._ckpt_dirpaths) >= keep_last_n_checkpoints
-        )
-
-    def _remove_snapshot(self, state: State) -> None:
-        # remove oldest snapshot directory
-        oldest_ckpt_path = self._ckpt_dirpaths.pop(0)
-        with get_timing_context(state, f"{self.__class__.__name__}.delete_snapshot"):
-            if self._pg_wrapper.get_rank() == 0:
-                # only delete on rank 0
-                _delete_checkpoint(oldest_ckpt_path, SNAPSHOT_METADATA_FNAME)
-            self._pg_wrapper.barrier()
+            with get_timing_context(state, f"{self.__class__.__name__}.take_snapshot"):
+                checkpoint_success = self._sync_snapshot(checkpoint_path, app_state)
+
+        return checkpoint_success
 
     def _wait(self) -> None:
         if self._prev_snapshot is not None:
@@ -366,7 +212,9 @@ def _async_snapshot(
         return True
 
     def _sync_snapshot(
-        self, snapshot_path: str, app_state: Dict[str, _TStateful], *, wait: bool
+        self,
+        snapshot_path: str,
+        app_state: Dict[str, _TStateful],
     ) -> bool:
         with _override_knobs(self._knob_options):
             rank_zero_info(
@@ -467,52 +315,6 @@ def restore(
         snapshot.restore(app_state)
         rank_zero_info(f"Restored snapshot from path: {path}", logger=logger)
 
-    @staticmethod
-    def restore_from_latest(
-        dirpath: str,
-        unit: AppStateMixin,
-        *,
-        train_dataloader: Optional[Iterable[TTrainData]] = None,
-        process_group: Optional[dist.ProcessGroup] = None,
-        restore_options: Optional[RestoreOptions] = None,
-        storage_options: Optional[Dict[str, Any]] = None,
-        knob_options: Optional[KnobOptions] = None,
-    ) -> bool:
-        """
-        Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.
-
-        There are additional flags offered should the user want to skip loading the train and eval progress.
-        By default, the train and eval progress are restored, if applicable.
-
-        Args:
-            dirpath: Parent directory from which to get the latest snapshot.
-            unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
-            train_dataloader: An optional train dataloader to restore.
-            process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
-            restore_options: Controls what to filter when restoring the state.
-            storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
-            knob_options: Additional keyword options for the snapshot knobs
-
-        Returns:
-            True if the latest snapshot directory was found and successfully restored, otherwise False.
-        """
-        path = get_latest_checkpoint_path(
-            dirpath, SNAPSHOT_METADATA_FNAME, process_group=process_group
-        )
-        if path is None:
-            return False
-        logger.info(f"Restoring from path: {path}")
-        TorchSnapshotSaver.restore(
-            path,
-            unit,
-            train_dataloader=train_dataloader,
-            process_group=process_group,
-            restore_options=restore_options,
-            storage_options=storage_options,
-            knob_options=knob_options,
-        )
-        return True
-
 
 def _validate_snapshot_available() -> None:
     if not _TORCHSNAPSHOT_AVAILABLE:
@@ -523,11 +325,6 @@ def _validate_snapshot_available() -> None:
         )
 
 
-def _get_snapshot_save_path(dirpath: str, epoch: int, step: int) -> str:
-    # TODO: discuss whether this path should be customized
-    return os.path.join(dirpath, f"epoch_{epoch}_step_{step}")
-
-
 def _app_state(unit: AppStateMixin) -> Dict[str, Any]:
     """Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
     app_state = unit.app_state()