diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 8bf4c3e9cc..d91e6aa791 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -9,7 +9,7 @@ import abc import logging from datetime import timedelta -from typing import Any, cast, Iterable, Literal, Optional, Union +from typing import Any, cast, Dict, Iterable, Literal, Optional, Union import torch.distributed as dist from pyre_extensions import none_throws @@ -170,7 +170,7 @@ def _generate_checkpoint_and_upkeep( value=metric_value, ) - checkpoint_path = self._checkpoint_manager.generate_checkpoint_path( + checkpoint_path = self._generate_checkpoint_path( epoch, step_mapping, metric_data, @@ -225,6 +225,20 @@ def _does_checkpoint_exist( checkpoint_path, process_group ) + def _generate_checkpoint_path( + self, + epoch: int, + step_mapping: Union[int, Dict[Phase, int]], + metric_data: Optional[MetricData] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> CheckpointPath: + return self._checkpoint_manager.generate_checkpoint_path( + epoch, + step_mapping, + metric_data, + process_group=process_group, + ) + def _get_tracked_metric_value( self, unit: Union[TTrainUnit, TEvalUnit] ) -> Optional[float]: diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 75dd7dc7c0..2f770fba8a 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -38,7 +38,12 @@ TTrainUnit, ) from torchtnt.framework.utils import get_timing_context -from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath +from torchtnt.utils.checkpoint import ( + BestCheckpointConfig, + CheckpointPath, + MetricData, + Phase, +) from torchtnt.utils.optimizer import init_optim_state from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import MultiStateful, 
Stateful @@ -385,6 +390,24 @@ def _does_checkpoint_exist( checkpoint_path=checkpoint_path, process_group=process_group ) + def _generate_checkpoint_path( + self, + epoch: int, + step_mapping: Union[int, Dict[Phase, int]], + metric_data: Optional[MetricData] = None, + process_group: Optional[dist.ProcessGroup] = None, + ) -> CheckpointPath: + # If an async checkpoint save is still in progress, the collective below could hang, + # so wait for the pending save to finish before generating the path. + self._wait() + + return super()._generate_checkpoint_path( + epoch=epoch, + step_mapping=step_mapping, + metric_data=metric_data, + process_group=process_group, + ) + @property + def default_writer_options(self) -> Dict[str, Any]: # defaults are picked to to match TSS defaults