Skip to content

Commit

Permalink
a
Browse files Browse the repository at this point in the history
  • Loading branch information
bigning committed Jun 14, 2024
1 parent 1280266 commit c0cb94d
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,26 +1885,31 @@ def _try_checkpoint_download(
self,
latest_checkpoint_path: str,
save_latest_remote_file_name: str,
loggers: Sequence[Union[LoggerDestination, ObjectStore]],
load_progress_bar: bool,
) -> None:
"""Attempts to download the checkpoint from the logger destinations."""
log.debug(
f'Trying to download {save_latest_remote_file_name} to {latest_checkpoint_path} on rank {dist.get_global_rank()}',
)
if self._checkpoint_saver is None or self._checkpoint_saver.remote_uploader is None:
log.debug(f'Skip downloading from remote since no remote object_store found')
return
try:
get_file(
path=save_latest_remote_file_name,
destination=latest_checkpoint_path,
object_store=self._checkpoint_saver.remote_uploader.remote_backend,
overwrite=True,
progress_bar=load_progress_bar,
)
except (FileNotFoundError):
log.info(f'Checkpoint not found in remote object store')
pass
remote_destination = list(loggers)
if self._checkpoint_saver is not None and self._checkpoint_saver.remote_uploader is not None:
remote_destination.append(self._checkpoint_saver.remote_uploader.remote_backend)
for logger in remote_destination:
try:
# Fetch from logger. If it succeeds, stop trying the rest of the loggers
get_file(
path=save_latest_remote_file_name,
destination=latest_checkpoint_path,
object_store=logger,
overwrite=True,
progress_bar=load_progress_bar,
)
break
except (NotImplementedError, FileNotFoundError):
log.info(f'Checkpoint not found in: {logger}')
# Ignore errors caused by no checkpoint saved with logger
pass

def _get_autoresume_checkpoint(
self,
Expand Down Expand Up @@ -1940,6 +1945,7 @@ def _get_autoresume_checkpoint(
self._try_checkpoint_download(
latest_checkpoint_path,
save_latest_remote_file_name,
loggers,
load_progress_bar,
)

Expand Down Expand Up @@ -1974,6 +1980,7 @@ def _get_autoresume_checkpoint(
self._try_checkpoint_download(
latest_checkpoint_path,
save_latest_remote_file_name,
loggers,
load_progress_bar,
)

Expand Down

0 comments on commit c0cb94d

Please sign in to comment.