Commit
[Checkpoint] Fix symlink issue where symlink file uploaded before checkpoint files upload (mosaicml#3376)

* fix test

* fix unit test

* fix 2gpu unit test

* fix doctest

* fix test and lint

* address comments

* rerun test

* add logging

* remove debug comments

* comments

* cleanup

* linter

* lint

* Update composer/callbacks/checkpoint_saver.py

Co-authored-by: Evan Racah <[email protected]>

* comments

* fix test

* fix test

* comments

---------

Co-authored-by: Evan Racah <[email protected]>
bigning and eracah committed Jul 8, 2024
1 parent e08dba0 commit 84fa1fe
Showing 12 changed files with 607 additions and 242 deletions.
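Among other changes, the diff below adds two new CheckpointSaver constructor arguments, num_concurrent_uploads and upload_timeout_in_seconds, and creates a RemoteUploader when the save folder points at a remote backend. The following is a minimal usage sketch, not taken from this commit: the bucket path, save_interval, and argument values are hypothetical, illustrative choices.

from composer.callbacks import CheckpointSaver

# Hypothetical example: a remote save folder makes CheckpointSaver create its
# own RemoteUploader; the two new arguments control upload concurrency and how
# long to wait for checkpoint files to appear before giving up on the symlink.
saver = CheckpointSaver(
    folder='s3://my-bucket/checkpoints/{run_name}',
    save_interval='1ep',
    num_concurrent_uploads=2,
    upload_timeout_in_seconds=3600,
)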
179 changes: 157 additions & 22 deletions composer/callbacks/checkpoint_saver.py
@@ -20,6 +20,8 @@
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
PartialFilePath,
RemoteFilesExistingCheckStatus,
RemoteUploader,
checkpoint,
create_interval_scheduler,
create_symlink_file,
@@ -28,6 +30,7 @@
format_name_with_dist,
format_name_with_dist_and_time,
is_model_deepspeed,
parse_uri,
partial_format,
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
@@ -287,8 +290,13 @@ def __init__(
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
num_concurrent_uploads: int = 1,
upload_timeout_in_seconds: int = 3600,
):
folder = str(folder)
backend, _, local_folder = parse_uri(str(folder))
if local_folder == '':
local_folder = '.'

filename = str(filename)
remote_file_name = str(remote_file_name) if remote_file_name is not None else None
latest_filename = str(latest_filename) if latest_filename is not None else None
@@ -304,10 +312,10 @@ def __init__(
self.save_interval = save_interval
self.last_checkpoint_batch: Optional[Time] = None

self.folder = folder
self.folder = local_folder

self.filename = PartialFilePath(filename.lstrip('/'), folder)
self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), folder) if latest_filename else None
self.filename = PartialFilePath(filename.lstrip('/'), local_folder)
self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), local_folder) if latest_filename else None
self.remote_file_name = PartialFilePath(remote_file_name) if remote_file_name else None
self.latest_remote_file_name = PartialFilePath(latest_remote_file_name) if latest_remote_file_name else None

@@ -320,6 +328,23 @@ def __init__(

self.start_batch = None

self.remote_uploader = None
self.rank_saves_symlinks: bool = False
self.tmp_dir_for_symlink = tempfile.TemporaryDirectory()
self.num_concurrent_uploads = num_concurrent_uploads
self.upload_timeout_in_seconds = upload_timeout_in_seconds
# Allow unit test to override this to make it faster
self._symlink_upload_wait_before_next_try_in_seconds = 30.0
self.pid = os.getpid()
self.symlink_count = 0
self.symlink_upload_tasks = []

if backend != '':
self.remote_uploader = RemoteUploader(
remote_folder=str(folder),
num_concurrent_uploads=self.num_concurrent_uploads,
)

def init(self, state: State, logger: Logger) -> None:
# If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
# Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
@@ -346,9 +371,10 @@ def init(self, state: State, logger: Logger) -> None:
self.latest_remote_file_name.filename,
**mlflow_format_kwargs,
)

break

if self.remote_uploader is not None:
self.remote_uploader.init()
folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

@@ -410,6 +436,27 @@ def load_state_dict(self, state: dict[str, Any]):
load_timestamp.load_state_dict(timestamp_state)
self.all_saved_checkpoints_to_timestamp[save_filename] = load_timestamp

def _upload_checkpoint(
self,
remote_file_name: str,
local_file_name: str,
local_remote_file_names: list[str],
logger: Logger,
):
if self.remote_uploader is not None:
self.remote_uploader.upload_file_async(
remote_file_name=remote_file_name,
file_path=pathlib.Path(local_file_name),
overwrite=self.overwrite,
)
local_remote_file_names.append(remote_file_name)
else:
logger.upload_file(
remote_file_name=remote_file_name,
file_path=local_file_name,
overwrite=self.overwrite,
)

def _save_checkpoint(self, state: State, logger: Logger):
self.last_checkpoint_batch = state.timestamp.batch

@@ -432,7 +479,14 @@ def _save_checkpoint(self, state: State, logger: Logger):
)
log.debug(f'Checkpoint locally saved to {saved_path}')

self.symlink_count += 1
# Remote checkpoint file names on this rank
local_remote_file_names = []
all_remote_filenames = []

if not saved_path: # not all ranks save
if self.remote_file_name is not None and self.remote_uploader is not None:
all_remote_filenames = dist.all_gather_object(local_remote_file_names)
return

metadata_local_file_path = None
@@ -443,6 +497,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state.timestamp,
)

self.rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if self.latest_filename is not None and self.num_checkpoints_to_keep != 0:
symlink = self.latest_filename.format(state, is_deepspeed)
os.makedirs(os.path.dirname(symlink), exist_ok=True)
@@ -455,8 +510,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
src_path = str(pathlib.Path(saved_path).parent)
else:
src_path = saved_path
this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
if self.rank_saves_symlinks:
os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink)

# if remote file name provided, upload the checkpoint
@@ -482,10 +536,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
state.timestamp,
)
assert metadata_local_file_path is not None
logger.upload_file(
self._upload_checkpoint(
remote_file_name=metadata_remote_file_name,
file_path=metadata_local_file_path,
overwrite=self.overwrite,
local_file_name=metadata_local_file_path,
local_remote_file_names=local_remote_file_names,
logger=logger,
)
else:
remote_file_name = self.remote_file_name.format(
@@ -495,12 +550,20 @@ def _save_checkpoint(self, state: State, logger: Logger):

log.debug(f'Uploading checkpoint to {remote_file_name}')
try:
logger.upload_file(remote_file_name=remote_file_name, file_path=saved_path, overwrite=self.overwrite)
self._upload_checkpoint(
remote_file_name=remote_file_name,
local_file_name=saved_path,
local_remote_file_names=local_remote_file_names,
logger=logger,
)
except FileExistsError as e:
raise FileExistsError(
f'Uploading checkpoint failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite checkpoints with Trainer, set save_overwrite to True.',
) from e

if self.remote_uploader is not None:
all_remote_filenames = dist.all_gather_object(local_remote_file_names)

# symlinks stay the same with sharded checkpointing
if self.latest_remote_file_name is not None:
symlink_name = self.latest_remote_file_name.format(
@@ -509,17 +572,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
).lstrip('/') + '.symlink'

# create and upload a symlink file
with tempfile.TemporaryDirectory() as tmpdir:
symlink_filename = os.path.join(tmpdir, 'latest.symlink')
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(remote_file_name).parent)
symlink_filename = os.path.join(
self.tmp_dir_for_symlink.name,
f'latest.{self.symlink_count}.symlink',
)
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(remote_file_name).parent)
else:
src_path = remote_file_name
log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
if self.rank_saves_symlinks:
create_symlink_file(src_path, symlink_filename)
if self.remote_uploader is not None:
remote_checkpoint_file_names = []
for file_names in all_remote_filenames:
remote_checkpoint_file_names += file_names
check_remote_files_exist_future = self.remote_uploader.check_remote_files_exist_async(
remote_checkpoint_file_names=remote_checkpoint_file_names,
max_wait_time_in_seconds=self.upload_timeout_in_seconds,
wait_before_next_try_in_seconds=self._symlink_upload_wait_before_next_try_in_seconds,
)
self.symlink_upload_tasks.append(
(check_remote_files_exist_future, symlink_filename, symlink_name),
)
else:
src_path = remote_file_name
log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
create_symlink_file(src_path, symlink_filename)
logger.upload_file(
remote_file_name=symlink_name,
file_path=symlink_filename,
@@ -532,7 +609,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
self._rotate_checkpoints(sharding_enabled=state.fsdp_sharded_state_dict_enabled)

def _rotate_checkpoints(self, sharding_enabled: bool = False):

while len(self.saved_checkpoints) > self.num_checkpoints_to_keep:
prefix_dir = None
checkpoint_to_delete = self.saved_checkpoints.pop(0)
@@ -542,3 +618,62 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
else:
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)

def batch_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
if self.remote_uploader is None:
return
self.remote_uploader.check_workers()
if not self.rank_saves_symlinks:
return
undone_symlink_upload_tasks = []
for (check_remote_files_exist_future, local_symlink_file,
remote_symlink_file) in reversed(self.symlink_upload_tasks):
if not check_remote_files_exist_future.done():
undone_symlink_upload_tasks.insert(
0,
(check_remote_files_exist_future, local_symlink_file, remote_symlink_file),
)
continue
if check_remote_files_exist_future.done():
result = check_remote_files_exist_future.result()
if result == RemoteFilesExistingCheckStatus.EXIST:
self.remote_uploader.upload_file_async(
remote_file_name=remote_symlink_file,
file_path=local_symlink_file,
overwrite=True,
)
break
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
self.symlink_upload_tasks = undone_symlink_upload_tasks

def fit_end(self, state: State, logger: Logger) -> None:
del state, logger # unused
if self.remote_uploader is None:
return
log.info('Waiting for checkpoint uploading to finish')
self.remote_uploader.wait()
if self.rank_saves_symlinks and len(self.symlink_upload_tasks) > 0:
log.debug('Uploading symlink to the latest checkpoint')
# We only need to upload a symlink pointing to the latest checkpoint files, so we can ignore successful uploads of older checkpoints.
check_remote_files_exist_future, local_symlink_file, remote_symlink_file = self.symlink_upload_tasks[-1]
result = check_remote_files_exist_future.result()
if result == RemoteFilesExistingCheckStatus.EXIST:
symlink_upload_future = self.remote_uploader.upload_file_async(
remote_file_name=remote_symlink_file,
file_path=local_symlink_file,
overwrite=True,
)
symlink_upload_future.result()
else:
raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
log.info('Checkpoint uploading finished!')

def post_close(self):
if self.remote_uploader is not None:
# Wait for the symlink file upload to finish and close the remote uploader
try:
self.remote_uploader.wait_and_close()
except Exception as e:
log.error(f'RemoteUploader run into exception {e}')
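
Distilled from the diff above, the ordering fix is: start the checkpoint uploads asynchronously, asynchronously verify that those files exist remotely, and only then upload the "latest" symlink file. The sketch below illustrates that pattern in isolation; wait_for_remote_files, upload_checkpoint_then_symlink, upload_fn, and exists_fn are stand-ins invented for this example, not Composer's RemoteUploader API.

import time
from concurrent.futures import ThreadPoolExecutor

def wait_for_remote_files(remote_files, exists_fn, timeout_s=3600, poll_s=30):
    # Poll until every checkpoint file is visible remotely, or give up.
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        if all(exists_fn(name) for name in remote_files):
            return True
        time.sleep(poll_s)
    return False

def upload_checkpoint_then_symlink(upload_fn, exists_fn, checkpoint_files, symlink_file):
    with ThreadPoolExecutor(max_workers=2) as pool:
        # 1) Kick off the checkpoint uploads asynchronously.
        for name in checkpoint_files:
            pool.submit(upload_fn, name)
        # 2) Separately wait until the uploaded files are visible remotely.
        check_future = pool.submit(wait_for_remote_files, checkpoint_files, exists_fn)
        # 3) Only then upload the symlink that points at them, so a reader
        #    following the symlink never finds missing checkpoint files.
        if check_future.result():
            pool.submit(upload_fn, symlink_file).result()
        else:
            raise RuntimeError('Checkpoint files never appeared remotely; symlink not uploaded')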
