From d3e85dc07ec469001388c791e2470ae266a7bfe8 Mon Sep 17 00:00:00 2001
From: Diego Urgell
Date: Wed, 28 Aug 2024 16:00:39 -0700
Subject: [PATCH] Remove support for deprecated DCP APIs in DCPSaver callback
 (#890)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/890

Reviewed By: anshulverma, JKSenthil

Differential Revision: D61887203

fbshipit-source-id: 17bd899a9b88033feb0285f3a395be6edbf82d5a
---
 tests/framework/callbacks/test_dcp_saver.py | 11 +---
 torchtnt/framework/callbacks/dcp_saver.py   | 70 ++++++---------------
 2 files changed, 23 insertions(+), 58 deletions(-)

diff --git a/tests/framework/callbacks/test_dcp_saver.py b/tests/framework/callbacks/test_dcp_saver.py
index 2fbb6e0a5c..702fcd3dd5 100644
--- a/tests/framework/callbacks/test_dcp_saver.py
+++ b/tests/framework/callbacks/test_dcp_saver.py
@@ -7,18 +7,11 @@
 
 # pyre-strict
 
-import unittest
-
-from torchtnt.framework.callbacks.dcp_saver import _LATEST_DCP_AVAIL
-from torchtnt.framework.state import State
-
-if not _LATEST_DCP_AVAIL:
-    raise unittest.SkipTest("Latest Pytorch is required to run DCP tests")
-
 import math
 import os
 import shutil
 import tempfile
+import unittest
 from typing import Any, Dict, Iterator, List, Optional
 from unittest import mock
 from unittest.mock import MagicMock, patch
@@ -40,6 +33,8 @@
 )
 from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
 from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
+
+from torchtnt.framework.state import State
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
 from torchtnt.utils.env import seed
diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 2ab9191242..113bd3d8a7 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -16,6 +16,11 @@
 import torch.distributed as dist
 from pyre_extensions import none_throws
 from torch.distributed import checkpoint as dcp
+
+from torch.distributed.checkpoint._fsspec_filesystem import (
+    FsspecReader as Reader,
+    FsspecWriter as Writer,
+)
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
@@ -45,25 +50,8 @@
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
-
 logger: logging.Logger = logging.getLogger(__name__)
 
-_LATEST_DCP_AVAIL: bool = True
-try:
-    from torch.distributed.checkpoint._fsspec_filesystem import (
-        FsspecReader as Reader,
-        FsspecWriter as Writer,
-    )
-except ModuleNotFoundError:
-    logger.warn(
-        "To use FsspecReader / FsspecWriter, please install latest pytorch version"
-    )
-    _LATEST_DCP_AVAIL = False
-    from torch.distributed.checkpoint import (
-        FileSystemReader as Reader,
-        FileSystemWriter as Writer,
-    )
-
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -248,24 +236,13 @@ def _save(
         if storage_writer is None:
             storage_writer = Writer(checkpoint_id, **self.default_writer_options)
 
-        try:
-            dcp.save(
-                state_dict={"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
-        except AttributeError as ex:
-            logger.warning(
-                f"Unable to save checkpoint (will retry saving using deprecated API). Error: {ex}"
-            )
-            dcp.save_state_dict(
-                state_dict={"app_state": MultiStateful(app_state)},
-                process_group=self._process_group,
-                storage_writer=storage_writer,
-                planner=planner,
-            )
+        dcp.save(
+            state_dict={"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            process_group=self._process_group,
+            storage_writer=storage_writer,
+            planner=planner,
+        )
 
         return True
 
@@ -397,21 +374,14 @@ def restore_with_id(
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)
 
-        try:
-            dcp.load(
-                {"app_state": MultiStateful(app_state)},
-                checkpoint_id=checkpoint_id,
-                storage_reader=storage_reader,
-                planner=planner,
-                process_group=process_group,
-            )
-        except AttributeError:
-            dcp.load_state_dict(
-                {"app_state": MultiStateful(app_state)},
-                storage_reader=storage_reader,
-                process_group=process_group,
-                planner=planner,
-            )
+        dcp.load(
+            {"app_state": MultiStateful(app_state)},
+            checkpoint_id=checkpoint_id,
+            storage_reader=storage_reader,
+            planner=planner,
+            process_group=process_group,
+        )
+
         rank_zero_info(
             f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger
         )
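
For reference, below is a minimal single-process sketch of the dcp.save / dcp.load
call pattern that DistributedCheckpointSaver now relies on unconditionally. It
assumes PyTorch >= 2.2, where torch.distributed.checkpoint exposes save()/load()
accepting a checkpoint_id; the toy model and temporary checkpoint location are
illustrative only and do not appear in the patch.

    # Sketch only: exercises the non-deprecated DCP entry points this patch
    # keeps (dcp.save / dcp.load). No process group is initialized here, so
    # DCP falls back to single-process save/load.
    import tempfile

    import torch
    import torch.distributed.checkpoint as dcp

    model = torch.nn.Linear(4, 2)
    checkpoint_id = tempfile.mkdtemp()  # illustrative checkpoint location

    # Save: with no storage_writer given, checkpoint_id doubles as the target
    # path, mirroring how the callback forwards checkpoint_id to dcp.save.
    dcp.save(state_dict={"model": model.state_dict()}, checkpoint_id=checkpoint_id)

    # Load: DCP restores in place into the provided state_dict.
    state_dict = {"model": model.state_dict()}
    dcp.load(state_dict, checkpoint_id=checkpoint_id)
    model.load_state_dict(state_dict["model"])

Inside the callback, the same two calls are made with {"app_state":
MultiStateful(app_state)} as the state_dict and an FsspecWriter / FsspecReader
as the storage layer, as shown in the hunks above.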