Race condition fix in checkpoint loading util (#3001)
* HSDP fix race condition

* bug fix

* clean up

* update comments

* Update composer/utils/checkpoint.py

Co-authored-by: Mihir Patel <[email protected]>

* bug fix

* linter

* Update composer/utils/checkpoint.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/utils/checkpoint.py

Co-authored-by: Mihir Patel <[email protected]>

* Update composer/utils/checkpoint.py

Co-authored-by: Mihir Patel <[email protected]>

---------

Co-authored-by: Mihir Patel <[email protected]>
jessechancy and mvpatel2000 authored Feb 13, 2024
1 parent 157af10 commit 30e6525
Showing 1 changed file with 22 additions and 6 deletions.
composer/utils/checkpoint.py (28 changes: 22 additions & 6 deletions)
@@ -217,17 +217,33 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
        first_replica = self.device_mesh is None or self.device_mesh.ndim == 1 or (
            self.device_mesh.ndim >= 2 and self.device_mesh.get_local_rank(mesh_dim=0) == 0)

-        # 1. Download to the destination all files this rank needs if on first replica
+        # 1. Collect the relative paths to download for all ranks for deduplication
+        relative_file_paths = set()
+        for plan_item in plan.items:
+            relative_file_paths.add(self.storage_data[plan_item.storage_index].relative_path)
+        all_file_paths = dist.all_gather_object(relative_file_paths)

+        # 2. Download to the destination all files this rank needs if on first replica
        if first_replica:
            log.debug(f'Rank {dist.get_global_rank()} starting to download files.')

+            # Get the lowest rank in the current node
+            local_rank_0 = dist.get_global_rank() - dist.get_local_rank()

            for plan_item in plan.items:
                # Each plan item has a storage index which points to the relative path of the shard file at save time.
                relative_file_path = self.storage_data[plan_item.storage_index].relative_path
+                # Check if the file is scheduled to be downloaded by a lower rank on the same node
+                # i.e. if rank 0 and rank 1 on the same node have the same required file,
+                # only rank 0 should download it and not rank 1.
+                is_downloaded = any(
+                    relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank()))

                # Download the shard file to the relative path it's associated to and save that relative path
                # to the root directory specified to the FileSystem reader constructor.
                file_destination = str(Path(self.destination_path) / Path(relative_file_path))

                # The file could have already been downloaded as different plan items can point to same file.
-                if not os.path.exists(file_destination):
+                if not is_downloaded and not os.path.exists(file_destination):
                    log.debug(f'Downloading {relative_file_path} to {file_destination}.')
                    object_name = str(Path(self.source_path) / Path(relative_file_path))
                    if isinstance(self.object_store, ObjectStore):
@@ -242,12 +258,12 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
                        )
                    log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')

-        # 2. Wait for all ranks to finish.
+        # 3. Wait for all ranks to finish.
        log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
        dist.barrier()
        log.debug('Done waiting for all ranks to finish downloading files.')

-        # 3. Broadcast files to all other replicas if HSDP
+        # 4. Broadcast files to all other replicas if HSDP
        if self.device_mesh is not None and self.device_mesh.ndim == 2:
            # Broadcast file to all replicas
            replicate_process_group = self.device_mesh.get_group(0)
@@ -288,7 +304,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
            f'Done waiting for all ranks to finish transferring files. Local checkpoint files: {os.listdir(self.destination_path)}'
        )

-        # 4. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
+        # 5. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
        return super().read_data(plan, planner)
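
For reference, here is a minimal standalone sketch of the per-node deduplication rule the diff introduces. The function name plan_downloads, the two-node topology, and the shard file names are illustrative assumptions, not part of the commit; the real code derives the same information from dist.all_gather_object and the load plan.

def plan_downloads(all_file_paths, global_rank, local_rank):
    """Files this rank should fetch, skipping anything a lower rank on the same node already plans to fetch."""
    local_rank_0 = global_rank - local_rank  # lowest global rank on this node
    to_download = []
    for path in sorted(all_file_paths[global_rank]):
        scheduled_by_lower_rank = any(
            path in all_file_paths[r] for r in range(local_rank_0, global_rank))
        if not scheduled_by_lower_rank:
            to_download.append(path)
    return to_download

if __name__ == '__main__':
    # Illustrative example: two nodes, two ranks per node; every rank needs a
    # shared metadata file plus its own shard (file names are made up).
    needed = [
        {'.metadata', '__0_0.distcp'},  # global rank 0 (node 0, local rank 0)
        {'.metadata', '__1_0.distcp'},  # global rank 1 (node 0, local rank 1)
        {'.metadata', '__2_0.distcp'},  # global rank 2 (node 1, local rank 0)
        {'.metadata', '__3_0.distcp'},  # global rank 3 (node 1, local rank 1)
    ]
    for rank in range(4):
        print(rank, plan_downloads(needed, rank, local_rank=rank % 2))

In this sketch the shared file is fetched exactly once per node, by the lowest local rank, which appears to be the race the old os.path.exists-only guard left open when two ranks on the same node started downloading the same file concurrently.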


