diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c6b98079e6..8967e5e33c 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1885,26 +1885,31 @@ def _try_checkpoint_download( self, latest_checkpoint_path: str, save_latest_remote_file_name: str, + loggers: Sequence[Union[LoggerDestination, ObjectStore]], load_progress_bar: bool, ) -> None: """Attempts to download the checkpoint from the logger destinations.""" log.debug( f'Trying to download {save_latest_remote_file_name} to {latest_checkpoint_path} on rank {dist.get_global_rank()}', ) - if self._checkpoint_saver is None or self._checkpoint_saver.remote_uploader is None: - log.debug(f'Skip downloading from remote since no remote object_store found') - return - try: - get_file( - path=save_latest_remote_file_name, - destination=latest_checkpoint_path, - object_store=self._checkpoint_saver.remote_uploader.remote_backend, - overwrite=True, - progress_bar=load_progress_bar, - ) - except (FileNotFoundError): - log.info(f'Checkpoint not found in remote object store') - pass + remote_destination = list(loggers) + if self._checkpoint_saver is not None and self._checkpoint_saver.remote_uploader is not None: + remote_destination.append(self._checkpoint_saver.remote_uploader.remote_backend) + for logger in remote_destination: + try: + # Fetch from logger. If it succeeds, stop trying the rest of the loggers + get_file( + path=save_latest_remote_file_name, + destination=latest_checkpoint_path, + object_store=logger, + overwrite=True, + progress_bar=load_progress_bar, + ) + break + except (NotImplementedError, FileNotFoundError): + log.info(f'Checkpoint not found in: {logger}') + # Ignore errors caused by no checkpoint saved with logger + pass def _get_autoresume_checkpoint( self, @@ -1940,6 +1945,7 @@ def _get_autoresume_checkpoint( self._try_checkpoint_download( latest_checkpoint_path, save_latest_remote_file_name, + loggers, load_progress_bar, ) @@ -1974,6 +1980,7 @@ def _get_autoresume_checkpoint( self._try_checkpoint_download( latest_checkpoint_path, save_latest_remote_file_name, + loggers, load_progress_bar, )