Merge branch 'dev' into batch_code_eval
dakinggg committed Feb 10, 2024
2 parents b6dc973 + 9bb32bc commit 9069663
Showing 4 changed files with 147 additions and 61 deletions.
4 changes: 2 additions & 2 deletions composer/trainer/mosaic_fsdp.py
@@ -70,11 +70,11 @@ def patch_pytorch():
        from torch.distributed.fsdp import _runtime_utils
        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

-       # Monkeypath state_dict
+       # Monkeypatch state_dict
        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p3p0
        FullyShardedDataParallel.__init__ = init_fn_t2p3p0

-       # Monkeypath state_dict
+       # Monkeypatch state_dict
        from torch.distributed.checkpoint import state_dict  # type: ignore

        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p3p0
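The comment being fixed above sits inside Composer's runtime monkeypatching of FSDP. A minimal sketch of the pattern, assuming only that torch is installed; the wrapper below is illustrative and is not the real init_fn_t2p3p0:

import functools

from torch.distributed.fsdp import FullyShardedDataParallel

_original_init = FullyShardedDataParallel.__init__

@functools.wraps(_original_init)
def _patched_init(self, *args, **kwargs):
    # Illustrative stand-in: the real patch (init_fn_t2p3p0) replaces the
    # method outright to adjust FSDP initialization for newer torch releases.
    _original_init(self, *args, **kwargs)

# Rebinding the class attribute makes every later FSDP(...) call use the patch.
FullyShardedDataParallel.__init__ = _patched_init

Because the patch is applied to the class object itself, it takes effect process-wide without touching torch's source tree.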
162 changes: 116 additions & 46 deletions composer/utils/checkpoint.py
@@ -21,6 +21,7 @@
import torch
from packaging import version
from torch.distributed import checkpoint as dist_cp
+from torch.distributed._tensor import DeviceMesh
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
@@ -135,6 +136,16 @@ def _get_write_mode(name: str) -> str:
    raise ValueError(f'{name} does not end with a valid tarfile extension.')


+def _get_num_ranks_that_saved_rng(metadata: Metadata):
+    rng_inds = []
+    for field_name, field_value in metadata.planner_data.items():
+        if 'rng' in field_name:
+            _, rng_rank_index, _ = field_value
+            rng_inds.append(rng_rank_index)
+    rng_inds = set(rng_inds)
+    return len(rng_inds)
+
+
class FileSystemReaderWithValidation(dist_cp.FileSystemReader):
    """FileSystemReader that validates checkpoint files prior to reading."""

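A small illustration of what _get_num_ranks_that_saved_rng computes. The planner_data layout below is an assumption inferred from the tuple unpacking in the function, not something stated in this commit:

from types import SimpleNamespace

# Hypothetical planner_data: each value is a (key, saving_rank, extra)-style
# 3-tuple, and only the middle element of the 'rng' entries is used.
fake_metadata = SimpleNamespace(planner_data={
    'rng.0': ('rng.0', 0, None),
    'rng.1': ('rng.1', 1, None),
    'model.weight': ('model.weight', 0, None),
})
assert _get_num_ranks_that_saved_rng(fake_metadata) == 2  # two distinct rng ranks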
@@ -172,17 +183,28 @@ def read_metadata(self) -> Metadata:
# A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
class DistCPObjectStoreReader(FileSystemReaderWithValidation):

-   def __init__(self, source_path: str, destination_path: str, object_store):
+   def __init__(self, source_path: str, destination_path: str, object_store: Union[ObjectStore, LoggerDestination],
+                device_mesh: Optional[DeviceMesh]):
        self.source_path = source_path
        self.destination_path = destination_path
        self.object_store = object_store
+       self.device_mesh = device_mesh

        # Download metadata file.
        Path(self.destination_path).mkdir(parents=True, exist_ok=True)
        metadata_destination = os.path.join(self.destination_path, '.metadata')
        if dist.get_local_rank() == 0:
-           object_store.download_object(object_name=str(Path(source_path) / Path('.metadata')),
-                                        filename=metadata_destination)
+           metadata_path = str(Path(source_path) / Path('.metadata'))
+           if isinstance(object_store, ObjectStore):
+               object_store.download_object(
+                   object_name=metadata_path,
+                   filename=metadata_destination,
+               )
+           else:
+               object_store.download_file(
+                   remote_file_name=metadata_path,
+                   destination=metadata_destination,
+               )
        dist.barrier()

        # FileSystemReader takes in a root directory in its constructor, which is the dir where
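The isinstance branching introduced in __init__ above reappears twice more in this file (in read_data and in is_checkpoint_legacy_sharded). A hedged sketch of the underlying pattern; the helper name is hypothetical and not part of the commit:

from typing import Union

from composer.loggers import LoggerDestination
from composer.utils import ObjectStore

def _download(store: Union[ObjectStore, LoggerDestination], name: str, dest: str) -> None:
    # ObjectStore and LoggerDestination expose different download methods with
    # different parameter names, so call sites branch on the concrete type.
    if isinstance(store, ObjectStore):
        store.download_object(object_name=name, filename=dest)
    else:
        store.download_file(remote_file_name=name, destination=dest)

The commit keeps the branch inline at each call site rather than factoring it out; the sketch only names the shared shape of those three blocks.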
@@ -191,22 +213,80 @@ def __init__(self, source_path: str, destination_path: str, object_store):
        super().__init__(destination_path)

    def read_data(self, plan: LoadPlan, planner: LoadPlanner):
-       # 1. Download to the destination all files that this rank is responsible for.
-       for plan_item in plan.items:
-           # Each plan item has a storage index which points to the relative path of the shard file at save time.
-           relative_file_path = self.storage_data[plan_item.storage_index].relative_path
-           # Download the shard file to the relative path it's associated to and save that relative path
-           # to the root directory specified to the FileSystem reader constructor.
-           file_destination = str(Path(self.destination_path) / Path(relative_file_path))
-           # The file could have already been downloaded as diffeent plan items can point to same file.
-           if not os.path.exists(file_destination):
-               self.object_store.download_object(object_name=str(Path(self.source_path) / Path(relative_file_path)),
-                                                 filename=file_destination)
+       first_replica = self.device_mesh is None or self.device_mesh.get_local_rank(mesh_dim=0) == 0
+
+       # 1. Download to the destination all files this rank needs if on first replica
+       if first_replica:
+           log.debug(f'Rank {dist.get_global_rank()} starting to download files.')
+           for plan_item in plan.items:
+               # Each plan item has a storage index which points to the relative path of the shard file at save time.
+               relative_file_path = self.storage_data[plan_item.storage_index].relative_path
+               # Download the shard file to the relative path it's associated to and save that relative path
+               # to the root directory specified to the FileSystem reader constructor.
+               file_destination = str(Path(self.destination_path) / Path(relative_file_path))
+               # The file could have already been downloaded as different plan items can point to same file.
+               if not os.path.exists(file_destination):
+                   log.debug(f'Downloading {relative_file_path} to {file_destination}.')
+                   object_name = str(Path(self.source_path) / Path(relative_file_path))
+                   if isinstance(self.object_store, ObjectStore):
+                       self.object_store.download_object(
+                           object_name=object_name,
+                           filename=file_destination,
+                       )
+                   else:
+                       self.object_store.download_file(
+                           remote_file_name=object_name,
+                           destination=file_destination,
+                       )
+                   log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
+
+       # 2. Wait for all ranks to finish.
+       log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
+       dist.barrier()
+       log.debug('Done waiting for all ranks to finish downloading files.')
+
+       # 3. Broadcast files to all other replicas if HSDP
+       if self.device_mesh is not None and self.device_mesh.ndim == 2:
+           # Broadcast file to all replicas
+           replicate_process_group = self.device_mesh.get_group(0)
+           shard_size = self.device_mesh.size(1)
+           rank_in_first_replica = dist.get_global_rank() % shard_size
+           sender = dist.get_global_rank() == rank_in_first_replica
+           receiver = dist.get_global_rank() != rank_in_first_replica
+
+           # Send list of files to all ranks
+           file_list = [sorted(os.listdir(self.destination_path))]
+           dist.broadcast_object_list(file_list, src=rank_in_first_replica, group=replicate_process_group)
+           file_list = file_list[0]
+           log.debug(f'List of files to broadcast: {file_list}')
+
+           # Send each file to the appropriate rank
+           for file_name in file_list:
+               if 'metadata' in file_name:  # All ranks already have the metadata file
+                   continue
+               if dist.get_local_rank() == 0:  # Only 1 rank per node needs to transfer file
+                   full_path = os.path.join(self.destination_path, file_name)
+                   log.debug(f'Transferring {full_path=}')
+                   file_object = [None]
+                   if sender:
+                       with open(full_path, 'rb') as f:
+                           file_object = [{'content': f.read()}]
+                   dist.broadcast_object_list(file_object,
+                                              src=dist.get_global_rank() % shard_size,
+                                              group=replicate_process_group)
+                   received_file_object = file_object[0]
+                   assert received_file_object is not None
+                   if receiver and not os.path.exists(full_path):
+                       with open(full_path, 'wb') as f:
+                           f.write(received_file_object['content'])
+
+           log.debug(f'Rank {dist.get_global_rank()} finished transferring files to all ranks.')
+           dist.barrier()
+           log.debug(
+               f'Done waiting for all ranks to finish transferring files. Local checkpoint files: {os.listdir(self.destination_path)}'
+           )

-       # 3. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
+       # 4. Piggyback off of the FileSystemReader to read all the files now that they are downloaded.
        return super().read_data(plan, planner)

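Step 3 above moves shard files between replicas by value over the process group instead of assuming a shared filesystem. A self-contained sketch of that idiom, assuming an initialized process group and files small enough to hold in memory (the function name and setup are illustrative):

import os

import torch.distributed as dist

def broadcast_file(path: str, src_rank: int, group=None) -> None:
    payload = [None]
    if dist.get_rank() == src_rank:
        with open(path, 'rb') as f:
            payload = [f.read()]
    # broadcast_object_list pickles the payload on the source rank and
    # unpickles it on every other rank in the group.
    dist.broadcast_object_list(payload, src=src_rank, group=group)
    if dist.get_rank() != src_rank and not os.path.exists(path):
        with open(path, 'wb') as f:
            f.write(payload[0])

Pickling whole files through broadcast_object_list trades memory and serialization overhead for not needing any shared storage between replicas.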
@@ -249,7 +329,16 @@ def is_checkpoint_legacy_sharded(object_store: Optional[ObjectStore], source_path: str):
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                metadata_destination = os.path.join(str(temp_dir), '.metadata')
-               object_store.download_object(object_name=metadata_path, filename=metadata_destination)
+               if isinstance(object_store, ObjectStore):
+                   object_store.download_object(
+                       object_name=metadata_path,
+                       filename=metadata_destination,
+                   )
+               else:
+                   object_store.download_file(
+                       remote_file_name=metadata_path,
+                       destination=metadata_destination,
+                   )
            return False
        except FileNotFoundError:
            return True
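The probe above classifies a checkpoint by whether a .metadata file can be fetched from its root; torch's distributed checkpoint format always writes one, while the legacy sharded format does not. A local-filesystem analogue of the same rule, as a sketch (the function name is hypothetical):

import os

def _is_legacy_sharded(checkpoint_dir: str) -> bool:
    # New-style (torch dist_cp) sharded checkpoints contain a top-level
    # .metadata file; its absence implies the legacy sharded format.
    return not os.path.exists(os.path.join(checkpoint_dir, '.metadata'))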
@@ -459,15 +548,6 @@ def load_sharded_checkpoint(
        load_planner = state.fsdp_config['load_planner']
        _validate_load_planner(load_planner)

-   def _get_num_ranks_that_saved_rng(metadata: Metadata):
-       rng_inds = []
-       for field_name, field_value in metadata.planner_data.items():
-           if 'rng' in field_name:
-               _, rng_rank_index, _ = field_value
-               rng_inds.append(rng_rank_index)
-       rng_inds = set(rng_inds)
-       return len(rng_inds)
-
    # Check to make sure source_path is a directory.
    if object_store is None:
        if os.path.islink(source_path):
@@ -485,10 +565,12 @@ def _get_num_ranks_that_saved_rng(metadata: Metadata):
            # Get the tempfile made on local rank 0.
            local_rank0_index = dist.get_global_rank() - dist.get_local_rank()
            rank0_download_tempdir = str(dist.all_gather_object(temp_download_dir)[local_rank0_index])
-           storage_reader = DistCPObjectStoreReader(source_path=source_path,
-                                                    destination_path=str(
-                                                        Path(rank0_download_tempdir) / Path('checkpoints')),
-                                                    object_store=object_store)
+           storage_reader = DistCPObjectStoreReader(
+               source_path=source_path,
+               destination_path=str(Path(rank0_download_tempdir) / Path('checkpoints')),
+               object_store=object_store,
+               device_mesh=state.fsdp_device_mesh,
+           )
    else:
        storage_reader = FileSystemReaderWithValidation(source_path)

@@ -517,34 +599,22 @@ def _get_num_ranks_that_saved_rng(metadata: Metadata):
        # Ensure state exists
        state_dict['state'] = state_dict.get('state', {})

-       # Only some ranks are meant to load checkpoint
-       expect_file = False
-       process_group = None
-       device_mesh = state.fsdp_device_mesh
-       if device_mesh is not None and device_mesh.ndim == 2:
-           # If hybrid shard, only rank in first replica saves
-           expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
-           if expect_file:
-               process_group = device_mesh.get_group(1)  # Shard process_group for first replica
-           log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
-       else:
-           expect_file = True

        if version.parse(torch.__version__) > version.parse('2.2.9'):
            dist_cp.load(  # type: ignore
                state_dict=state_dict,
                storage_reader=storage_reader,
                planner=load_planner,
-               process_group=process_group,
+               no_dist=(not dist.is_initialized()),
            )
        else:
            dist_cp.load_state_dict(
                state_dict=state_dict,
                storage_reader=storage_reader,
                planner=load_planner,
-               process_group=process_group,
+               no_dist=(not dist.is_initialized()),
            )

        log.info(f'Loaded state dict')
        state.load_state_dict(
            state_dict['state'],
            logger,
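The hunk above also shows the version gate that picks between torch's two checkpoint-load entry points. A minimal sketch of the check, assuming only torch and packaging are installed:

import torch
from packaging import version

# Versions newer than 2.2.9 (the 2.3 series and later) take the dist_cp.load
# path; older ones fall back to dist_cp.load_state_dict.
_use_new_load_api = version.parse(torch.__version__) > version.parse('2.2.9')

A common reason for comparing against 2.2.9 instead of 2.3.0 is to also catch 2.3.0 pre-release builds, whose version strings still parse below 2.3.0.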
@@ -1004,10 +1074,10 @@ def _save_checkpoint(
        device_mesh = state.fsdp_device_mesh
        if device_mesh is not None and device_mesh.ndim == 2:
            # If hybrid shard, only rank in first replica saves
-           expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0)
+           expect_file = device_mesh.get_local_rank(mesh_dim=0) == 0
            if expect_file:
                process_group = device_mesh.get_group(1)  # Shard process_group for first replica
-           log.debug(f'global_rank={dist.get_global_rank()}, {expect_file=}')
+           log.debug(f'Saving on global_rank={dist.get_global_rank()}, {expect_file=}')
        else:
            expect_file = True

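The save-side gating above mirrors the load-side change: with a two-dimensional (replicate, shard) device mesh, only ranks in the first replica write shard files, over that replica's shard process group. A sketch of deriving those pieces from a mesh, assuming torch >= 2.2, an initialized distributed job, and an illustrative 2 x 4 layout:

from torch.distributed.device_mesh import init_device_mesh

# Example HSDP layout: 2 replicas (dim 0) x 4 shards (dim 1).
mesh = init_device_mesh('cuda', (2, 4))

expect_file = mesh.get_local_rank(mesh_dim=0) == 0  # first replica only
process_group = mesh.get_group(1) if expect_file else None  # shard group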
