From a8d0e0e9f1f59cb97e0295090297dae3943483f4 Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 23 Sep 2024 15:35:06 -0400 Subject: [PATCH] remove legacy --- composer/utils/checkpoint.py | 26 +------------------------- tests/trainer/test_fsdp_checkpoint.py | 22 +--------------------- 2 files changed, 2 insertions(+), 46 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index b966c918c5..bee6d1c404 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -413,23 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo ) + extra_suffix -def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str): - if source_path.endswith('.symlink') or os.path.islink(source_path): - source_path = extract_path_from_symlink(source_path, object_store=object_store) - metadata_path = str(Path(source_path) / Path('.metadata')) - if object_store is None: - return not os.path.exists(metadata_path) - else: - try: - _, _, metadata_path = parse_uri(metadata_path) - with tempfile.TemporaryDirectory() as temp_dir: - metadata_destination = os.path.join(str(temp_dir), '.metadata') - download_object_or_file(metadata_path, metadata_destination, object_store) - return False - except FileNotFoundError: - return True - - def load_checkpoint( path: str, state: State, @@ -531,15 +514,8 @@ def load_checkpoint( :attr:`load_weights_only` is not None. Otherwise, None. """ path = partial_format(path, run_name=state.run_name) - using_legacy_sharded = False - if state.fsdp_sharded_state_dict_enabled: - assert object_store is None or isinstance( - object_store, - ObjectStore, - ), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore' - using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path) - if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded: + if state.fsdp_sharded_state_dict_enabled: rng_state_dicts = load_sharded_checkpoint( source_path=path, state=state, diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 5f97c6092c..f4d2ec16b9 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -31,7 +31,7 @@ from composer.optim import DecoupledAdamW from composer.trainer import Trainer from composer.utils import FSDPConfig, TPConfig, dist, parse_uri -from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded +from composer.utils.checkpoint import dist_cp_load from composer.utils.file_helpers import get_file from composer.utils.object_store import S3ObjectStore from composer.utils.reproducibility import get_rng_state @@ -548,10 +548,6 @@ def test_fsdp_load_old_checkpoint( load_path_dir = (load_path_dir + 'ep0-ba2/') load_path = load_path_dir + f'ba2_rank{rank}.pt' - assert is_checkpoint_legacy_sharded( - object_store=S3ObjectStore(bucket=f'{s3_bucket}'), - source_path=load_path.lstrip(f's3://{s3_bucket}/'), - ) else: load_path = ( f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/' @@ -909,16 +905,9 @@ def test_fsdp_partitioned_state_dict_load( load_path = 's3://' + save_folder.strip('s3://').format( run_name=run_name, ) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink') - object_store = S3ObjectStore(bucket=f'{s3_bucket}') else: - object_store = None load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2')) - assert not is_checkpoint_legacy_sharded( - object_store=object_store, - source_path=load_path.replace(f's3://{s3_bucket}/', ''), - ) - if autoresume: load_path = None trainer2 = get_trainer( @@ -1013,10 +1002,6 @@ def test_elastic_resumption( else: save_folder = None sharded_load_path = os.path.join(base_path, 'ba2') - assert not is_checkpoint_legacy_sharded( - object_store=S3ObjectStore(bucket=f'{s3_bucket}'), - source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''), - ) sharded_trainer = get_trainer( save_folder=save_folder, @@ -1237,11 +1222,6 @@ def set_up_planner( load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2')) - assert not is_checkpoint_legacy_sharded( - object_store=None, - source_path=load_path, - ) - trainer2 = get_trainer( save_folder=str(save_folder), load_path=load_path,