dcp checkpointer - ensure no distributed collectives while checkpoint is ongoing (#870)

Summary: Pull Request resolved: #870

Reviewed By: saumishr, diego-urgell

Differential Revision: D60174864

fbshipit-source-id: 69d15c0c889b766aae540c8dfeb5642d7e3ea339
galrotem authored and facebook-github-bot committed Jul 24, 2024
1 parent 745f5cb commit e4e7a9d
Showing 2 changed files with 40 additions and 3 deletions.
18 changes: 16 additions & 2 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -9,7 +9,7 @@
 import abc
 import logging
 from datetime import timedelta
-from typing import Any, cast, Iterable, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, Literal, Optional, Union
 
 import torch.distributed as dist
 from pyre_extensions import none_throws
@@ -170,7 +170,7 @@ def _generate_checkpoint_and_upkeep(
                 value=metric_value,
             )
 
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+        checkpoint_path = self._generate_checkpoint_path(
             epoch,
             step_mapping,
             metric_data,
@@ -225,6 +225,20 @@ def _does_checkpoint_exist(
             checkpoint_path, process_group
         )
 
+    def _generate_checkpoint_path(
+        self,
+        epoch: int,
+        step_mapping: Union[int, Dict[Phase, int]],
+        metric_data: Optional[MetricData] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> CheckpointPath:
+        return self._checkpoint_manager.generate_checkpoint_path(
+            epoch,
+            step_mapping,
+            metric_data,
+            process_group=process_group,
+        )
+
     def _get_tracked_metric_value(
         self, unit: Union[TTrainUnit, TEvalUnit]
     ) -> Optional[float]:
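For context, the change above turns checkpoint path generation into an overridable hook: _generate_checkpoint_and_upkeep now calls self._generate_checkpoint_path, whose default implementation simply forwards to the checkpoint manager. A minimal sketch of that template-method pattern follows; the class bodies and signatures are simplified stand-ins, not the actual torchtnt source.

# Minimal illustrative sketch of the hook pattern (simplified stand-ins,
# not the real torchtnt classes or signatures).
class CheckpointManager:
    def generate_checkpoint_path(self, epoch: int, step: int) -> str:
        return f"epoch_{epoch}_step_{step}"


class BaseCheckpointer:
    def __init__(self, checkpoint_manager: CheckpointManager) -> None:
        self._checkpoint_manager = checkpoint_manager

    def _generate_checkpoint_and_upkeep(self, epoch: int, step: int) -> None:
        # Route through the overridable hook instead of calling the manager
        # directly, so subclasses can run extra logic before path generation.
        checkpoint_path = self._generate_checkpoint_path(epoch, step)
        self._checkpoint_impl(checkpoint_path)

    def _generate_checkpoint_path(self, epoch: int, step: int) -> str:
        # Default behavior: forward straight to the checkpoint manager.
        return self._checkpoint_manager.generate_checkpoint_path(epoch, step)

    def _checkpoint_impl(self, checkpoint_path: str) -> None:
        # Placeholder for the actual save logic.
        pass

The second file below overrides this hook so the DCP saver can wait for any in-flight async checkpoint before it participates in new collectives.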
25 changes: 24 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
@@ -38,7 +38,12 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
+from torchtnt.utils.checkpoint import (
+    BestCheckpointConfig,
+    CheckpointPath,
+    MetricData,
+    Phase,
+)
 from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 from torchtnt.utils.stateful import MultiStateful, Stateful
@@ -385,6 +390,24 @@ def _does_checkpoint_exist(
             checkpoint_path=checkpoint_path, process_group=process_group
         )
 
+    def _generate_checkpoint_path(
+        self,
+        epoch: int,
+        step_mapping: Union[int, Dict[Phase, int]],
+        metric_data: Optional[MetricData] = None,
+        process_group: Optional[dist.ProcessGroup] = None,
+    ) -> CheckpointPath:
+        # if we are still checkpointing, this might cause a collective hang.
+        # so wait here instead
+        self._wait()
+
+        return super()._generate_checkpoint_path(
+            epoch=epoch,
+            step_mapping=step_mapping,
+            metric_data=metric_data,
+            process_group=process_group,
+        )
+
     @property
     def default_writer_options(self) -> Dict[str, Any]:
         # defaults are picked to to match TSS defaults
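The wait before path generation matters because path generation can itself involve a collective (it accepts a process group), while a previous asynchronous checkpoint may still be running collectives of its own; if the two interleave, ranks can block on mismatched collectives and hang. Below is a hedged sketch of that ordering, assuming an async save that hands back a future-like handle with result(); the names are illustrative and are not torchtnt internals.

# Illustrative sketch only; assumes torch.distributed is already initialized
# and that the previous async save returned a future-like handle.
import torch.distributed as dist


class AsyncCheckpointerSketch:
    def __init__(self) -> None:
        self._pending_save = None  # handle from the previous async checkpoint

    def _wait(self) -> None:
        # Block until the in-flight save (and any collectives it issues) finishes.
        if self._pending_save is not None:
            self._pending_save.result()
            self._pending_save = None

    def _generate_checkpoint_path(self, epoch: int, step: int) -> str:
        # Without this wait, the broadcast below could interleave with
        # collectives still being run by the previous async save and deadlock.
        self._wait()
        payload = [f"epoch_{epoch}_step_{step}"] if dist.get_rank() == 0 else [None]
        dist.broadcast_object_list(payload, src=0)
        return payload[0]

Draining the pending save inside the path-generation hook keeps every rank's sequence of collectives aligned, which is the intent of the override in the diff above.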