Skip to content

Commit

Permalink
Allow multiple metadata file names in checkpointers (#872)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #872

Reviewed By: galrotem

Differential Revision: D60246320

fbshipit-source-id: 24be55bcf6917a9c0b2eb5c539d707c843c9fbbb
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Jul 30, 2024
1 parent e125ae9 commit 5c73bd5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
16 changes: 9 additions & 7 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import abc
import logging
from datetime import timedelta
from typing import Any, cast, Dict, Iterable, Literal, Optional, Union
from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union

import torch.distributed as dist
from pyre_extensions import none_throws
Expand Down Expand Up @@ -43,7 +43,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
2) ``restore`` which implements restoring the checkpoint given the relevant checkpoint path.
The subclass may override the ``metadata_fname`` attribute to specify the filename of the metadata file that will be written within the checkpoint directory.
This will be used by this base class to ensure the integrity of the checkpoint.
This will be used by this base class to ensure the integrity of the checkpoint. This is a list because some checkpointers may accept more than one
valid metadata file name, depending on storage or optimization configurations.
Args:
dirpath: Parent directory to save checkpoints to.
Expand All @@ -67,7 +68,8 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
checkpoint will be saved, without the metric value in the checkpoint name
"""

metadata_fname: Optional[str] = None
# No metadata file is checked by default. This can be overridden by subclasses.
metadata_fnames: List[str] = []

def __init__(
self,
Expand Down Expand Up @@ -112,7 +114,7 @@ def __init__(
dirpath,
best_checkpoint_config,
keep_last_n_checkpoints,
metadata_fnames=[self.metadata_fname] if self.metadata_fname else None,
metadata_fnames=self.metadata_fnames,
process_group=self._process_group,
)

Expand Down Expand Up @@ -385,11 +387,11 @@ def restore_from_latest(
True if the latest checkpoint directory was found and successfully restored, otherwise False.
"""
path = get_latest_checkpoint_path(
dirpath, metadata_fname=cls.metadata_fname, process_group=process_group
dirpath, metadata_fname=cls.metadata_fnames, process_group=process_group
)
if path is None:
logger.info(
f"Attempted to restore from the following path but no checkpoint was found: {dirpath=}, {cls.metadata_fname}"
f"Attempted to restore from the following path but no checkpoint was found: {dirpath=}, {cls.metadata_fnames}"
)
return False
logger.info(f"Restoring from path: {path}")
Expand Down Expand Up @@ -438,7 +440,7 @@ def restore_from_best(
dirpath,
metric_name=metric_name,
mode=mode,
metadata_fname=cls.metadata_fname,
metadata_fname=cls.metadata_fnames,
process_group=process_group,
)

Expand Down
4 changes: 2 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import logging
import time
from concurrent.futures import Future
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -102,7 +102,7 @@ class DistributedCheckpointSaver(BaseCheckpointer):
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
"""

metadata_fname: Optional[str] = ".metadata"
metadata_fnames: List[str] = [".metadata"]

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class TorchSnapshotSaver(BaseCheckpointer):
appropriately. For example, if logging validation accuracy, the unit must be responsible for maintaining the value and resetting it when the epoch ends.
"""

metadata_fname: Optional[str] = ".snapshot_metadata"
metadata_fnames: List[str] = [".snapshot_metadata"]

def __init__(
self,
Expand Down

0 comments on commit 5c73bd5

Please sign in to comment.