From 29632931896c3d0f05b396a55c70231c45163ad1 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 1 Oct 2024 13:39:54 -0400
Subject: [PATCH] Remove Legacy Checkpointing (#3631)

---
 composer/utils/checkpoint.py          | 52 +++++----------------
 tests/trainer/test_fsdp_checkpoint.py | 27 +++-----------
 2 files changed, 13 insertions(+), 66 deletions(-)

diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py
index ccefee4e60..9f480059d3 100644
--- a/composer/utils/checkpoint.py
+++ b/composer/utils/checkpoint.py
@@ -413,24 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo
         ) + extra_suffix
 
 
-def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str):
-    if source_path.endswith('.symlink') or os.path.islink(source_path):
-        source_path = extract_path_from_symlink(source_path, object_store=object_store)
-    metadata_path = str(Path(source_path) / Path('.metadata'))
-    log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
-    if object_store is None:
-        return not os.path.exists(metadata_path)
-    else:
-        try:
-            _, _, metadata_path = parse_uri(metadata_path)
-            with tempfile.TemporaryDirectory() as temp_dir:
-                metadata_destination = os.path.join(str(temp_dir), '.metadata')
-                download_object_or_file(metadata_path, metadata_destination, object_store)
-            return False
-        except FileNotFoundError:
-            return True
-
-
 def load_checkpoint(
     path: str,
     state: State,
@@ -533,16 +515,8 @@ def load_checkpoint(
     """
     path = partial_format(path, run_name=state.run_name)
     log.debug(f'Loading checkpoint from formatted path: {path}')
-    using_legacy_sharded = False
+
     if state.fsdp_sharded_state_dict_enabled:
-        assert object_store is None or isinstance(
-            object_store,
-            ObjectStore,
-        ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
-        using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
-        log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')
-
-    if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
         rng_state_dicts = load_sharded_checkpoint(
             source_path=path,
             state=state,
         )
     else:
         # Download the checkpoint to the node-local folder
-        log.debug('Loading checkpoint at %s', path)
         # Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
         # If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder.
-        needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0
-        tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None)
+        tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
         with tempdir_ctx as tempdir:
             try:
                 # Get the path to the proper checkpoint folder corresponding to the current rank's node.
                 # If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir.
-                node_checkpoint_folder = (
-                    tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir)
-                )
-                assert node_checkpoint_folder is not None
+                node_checkpoint_folder = _get_local_rank_zero_path(tempdir)
                 composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
                     path=path,
                     node_checkpoint_folder=node_checkpoint_folder,
                     object_store=object_store,
                     progress_bar=progress_bar,
-                    fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
                     deepspeed_sharded_checkpoint=is_model_deepspeed(state.model),
                 )
                 rng_state_dicts = _restore_checkpoint(
@@ -596,6 +564,8 @@
             # be a shared resource between nodes.
             dist.barrier()
     log.info('%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path)
+
+    # Verify all ranks resumed on same step
     step_to_resume_from = state.timestamp.batch.value
     max_step_to_resume_from = state.device.tensor_to_device(
         torch.tensor(state.timestamp.batch.value, dtype=torch.int64),
@@ -802,7 +772,6 @@ def download_checkpoint(
     node_checkpoint_folder: str,
     object_store: Optional[Union[ObjectStore, LoggerDestination]],
     progress_bar: bool,
-    fsdp_sharded_state_dict_enabled: bool = False,
     deepspeed_sharded_checkpoint: bool = False,
 ) -> tuple[str, Optional[str], bool]:
     """Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
@@ -829,9 +798,7 @@
     # and only rank zero has this file unless fsdp_sharded_state_dict_enabled then
     # every rank has it's own file.
     extracted_checkpoint_folder = None
-    composer_states_filepath = (
-        rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath
-    )
+    composer_states_filepath = rank_zero_checkpoint_filepath
 
     if is_compressed_pt(path):
         original_path = path
@@ -841,9 +808,8 @@
             with compressor.decompress(original_path) as in_file:
                 shutil.copyfileobj(in_file, out_file)
 
-    checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint
     try:
-        if not checkpoint_is_sharded and dist.get_local_rank() == 0:
+        if not deepspeed_sharded_checkpoint and dist.get_local_rank() == 0:
             # If the checkpoint is not sharded, then local rank 0 on each node needs to download the
             # global rank 0 checkpoint
             path = _format_path_with_rank_zero(path)
@@ -862,7 +828,7 @@
                 # the underlying issue is that the checkpoint file does not exist on the disk
                 # or could not be downloaded
                 raise RuntimeError(f'Checkpoint {path} does not exist')
-        elif checkpoint_is_sharded:
+        elif deepspeed_sharded_checkpoint:
             # If the checkpoint is sharded, then every rank needs to download its own checkpoint
             path = _format_path_with_current_rank(path)
             try:
@@ -892,7 +858,7 @@
 
     finally:
         # Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints
-        if not checkpoint_is_sharded:
+        if not deepspeed_sharded_checkpoint:
             signal_file_path = os.path.join(
                 node_checkpoint_folder,
                 dist.get_node_signal_file_name(),
diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index f9b40e2e78..9f785a94ff 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -31,7 +31,7 @@
 from composer.optim import DecoupledAdamW
 from composer.trainer import Trainer
 from composer.utils import FSDPConfig, TPConfig, dist, parse_uri
-from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded
+from composer.utils.checkpoint import dist_cp_load
 from composer.utils.file_helpers import get_file
 from composer.utils.object_store import S3ObjectStore
 from composer.utils.reproducibility import get_rng_state
@@ -539,7 +539,8 @@ def test_fsdp_load_old_checkpoint(
         pytest.skip('Current torch version is older than torch version that checkpoint was written with.')
 
     if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
-        rank = 0 if state_dict_type == 'full' else '{rank}'
+        if state_dict_type == 'sharded':
+            pytest.skip('Loading legacy sharded checkpoints are not supported after v0.25.0.')
 
         load_path_dir = (
             f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
@@ -549,11 +550,7 @@
         if ((version.parse(composer_version) > version.parse('0.15.0')) and state_dict_type != 'full'):
             load_path_dir = (load_path_dir + 'ep0-ba2/')
 
-        load_path = load_path_dir + f'ba2_rank{rank}.pt'
-        assert is_checkpoint_legacy_sharded(
-            object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
-            source_path=load_path.lstrip(f's3://{s3_bucket}/'),
-        )
+        load_path = load_path_dir + f'ba2_rank0.pt'
     else:
         load_path = (
             f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
@@ -911,16 +908,9 @@ def test_fsdp_partitioned_state_dict_load(
         load_path = 's3://' + save_folder.strip('s3://').format(
             run_name=run_name,
         ) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink')
-        object_store = S3ObjectStore(bucket=f'{s3_bucket}')
     else:
-        object_store = None
         load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))
 
-    assert not is_checkpoint_legacy_sharded(
-        object_store=object_store,
-        source_path=load_path.replace(f's3://{s3_bucket}/', ''),
-    )
-
     if autoresume:
         load_path = None
     trainer2 = get_trainer(
@@ -1015,10 +1005,6 @@ def test_elastic_resumption(
     else:
         save_folder = None
     sharded_load_path = os.path.join(base_path, 'ba2')
-    assert not is_checkpoint_legacy_sharded(
-        object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
-        source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''),
-    )
 
     sharded_trainer = get_trainer(
         save_folder=save_folder,
@@ -1239,11 +1225,6 @@ def set_up_planner(
 
     load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))
 
-    assert not is_checkpoint_legacy_sharded(
-        object_store=None,
-        source_path=load_path,
-    )
-
     trainer2 = get_trainer(
         save_folder=str(save_folder),
         load_path=load_path,