From c97a632771489b745db087e96a87dcdb03566674 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Mon, 23 Sep 2024 15:04:30 -0700 Subject: [PATCH] Swap DCP restore ad-hoc gloo pg creation with context manager (#903) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/903 Differential Revision: D63268179 --- torchtnt/framework/callbacks/dcp_saver.py | 38 ++++++----------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index b1b4a232a5..14e99a1677 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -9,7 +9,6 @@ import logging import time from concurrent.futures import Future -from datetime import timedelta from typing import Any, Dict, Iterable, List, Optional, Union import torch.distributed as dist @@ -45,6 +44,7 @@ ) from torchtnt.framework.utils import get_timing_context from torchtnt.utils.checkpoint import BestCheckpointConfig +from torchtnt.utils.distributed import get_or_create_gloo_pg from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import MultiStateful, Stateful @@ -271,23 +271,6 @@ def restore_with_id( storage_reader: Instance of StorageReader used to perform reads. If this is not specified, it will automatically infer the reader based on the checkpoint_id. If checkpoint_id is also None, an exception will be raised. (Default: ``None``) """ - - # use gloo pg if available - gloo_pg_created = False - if dist.is_initialized(): - pg = dist.group.WORLD if process_group is None else process_group - - if dist.get_backend(pg) != dist.Backend.GLOO: - rank_zero_info( - "Creating new gloo process group for loading checkpoint." - ) - pg = dist.new_group( - timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO - ) - gloo_pg_created = True - else: - pg = process_group - restore_options = restore_options or RestoreOptions() app_state = _prepare_app_state_for_restore(unit, restore_options) checkpoint_id = str(checkpoint_id) @@ -321,22 +304,19 @@ def restore_with_id( "train_dataloader was passed to `restore` but no train dataloader exists in the Snapshot" ) - dcp.load( - {"app_state": MultiStateful(app_state)}, - checkpoint_id=checkpoint_id, - storage_reader=storage_reader, - planner=planner, - process_group=pg, - ) + with get_or_create_gloo_pg(candidate_pg=process_group) as pg: + dcp.load( + {"app_state": MultiStateful(app_state)}, + checkpoint_id=checkpoint_id, + storage_reader=storage_reader, + planner=planner, + process_group=pg, + ) rank_zero_info( f"Restored snapshot for checkpoint_id: {checkpoint_id}", logger=logger ) - # destroy gloo pg if created, its sole purpose was for checkpoint restore - if gloo_pg_created: - dist.destroy_process_group(pg) - def _generate_checkpoint_and_upkeep( self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str ) -> bool: