From 3c0a8179d2c3d60786226a73d8f935bc46306a31 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Thu, 30 May 2024 22:07:52 -0700 Subject: [PATCH] Raise errors on all ranks for checkpoint download failures (#3345) Co-authored-by: Ning Wang --- composer/utils/checkpoint.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 19ca1cd490..2d176b135a 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -248,6 +248,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): all_file_paths = dist.all_gather_object(relative_file_paths) # 2. Download to the destination all files this rank needs if on first replica + download_error = False if first_replica: log.debug(f'Rank {dist.get_global_rank()} starting to download files.') @@ -275,12 +276,26 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): download_object_or_file(object_name, file_destination, self.object_store) log.debug(f'Finished downloading {relative_file_path} to {file_destination}.') except Exception as e: - # PyTorch will capture any exception of this function, - # and dist.all_gather_objects(exception) before raising it. - # If that all_gather_objects fails, the exception is never visible to user. - # We immediately print the exception to avoid that situation. log.error(f'Exception {type(e)} raised during downloading: {str(e)}') - raise e + download_error = True + + # PyTorch will capture any exception of this function, + # and dist.all_gather_objects(exception) before raising it. + # If that all_gather_objects fails, the exception is never visible to user. + # We raise the exception from all ranks to ensure the user sees it. + download_error_tensor = dist.get_device(None).tensor_to_device(torch.tensor(1 if download_error else 0)) + error_by_rank = dist.all_gather(download_error_tensor) + failed_ranks = [] + for rank, error in enumerate(list(error_by_rank)): + if error > 0: + failed_ranks.append(rank) + download_error = True + + if download_error: + raise RuntimeError( + f'Ranks {failed_ranks} failed to download.', + 'To see the full error please look at the logs for that rank, which are logged via log.error.', + ) # 3. Wait for all ranks to finish. log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')