Skip to content

Commit

Permalink
remove legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 committed Sep 23, 2024
1 parent d2e1d5e commit a8d0e0e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 46 deletions.
26 changes: 1 addition & 25 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,23 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo
) + extra_suffix


def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str):
if source_path.endswith('.symlink') or os.path.islink(source_path):
source_path = extract_path_from_symlink(source_path, object_store=object_store)
metadata_path = str(Path(source_path) / Path('.metadata'))
if object_store is None:
return not os.path.exists(metadata_path)
else:
try:
_, _, metadata_path = parse_uri(metadata_path)
with tempfile.TemporaryDirectory() as temp_dir:
metadata_destination = os.path.join(str(temp_dir), '.metadata')
download_object_or_file(metadata_path, metadata_destination, object_store)
return False
except FileNotFoundError:
return True


def load_checkpoint(
path: str,
state: State,
Expand Down Expand Up @@ -531,15 +514,8 @@ def load_checkpoint(
:attr:`load_weights_only` is not None. Otherwise, None.
"""
path = partial_format(path, run_name=state.run_name)
using_legacy_sharded = False
if state.fsdp_sharded_state_dict_enabled:
assert object_store is None or isinstance(
object_store,
ObjectStore,
), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)

if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
if state.fsdp_sharded_state_dict_enabled:
rng_state_dicts = load_sharded_checkpoint(
source_path=path,
state=state,
Expand Down
22 changes: 1 addition & 21 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from composer.utils import FSDPConfig, TPConfig, dist, parse_uri
from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded
from composer.utils.checkpoint import dist_cp_load
from composer.utils.file_helpers import get_file
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
Expand Down Expand Up @@ -548,10 +548,6 @@ def test_fsdp_load_old_checkpoint(
load_path_dir = (load_path_dir + 'ep0-ba2/')

load_path = load_path_dir + f'ba2_rank{rank}.pt'
assert is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=load_path.lstrip(f's3://{s3_bucket}/'),
)
else:
load_path = (
f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
Expand Down Expand Up @@ -909,16 +905,9 @@ def test_fsdp_partitioned_state_dict_load(
load_path = 's3://' + save_folder.strip('s3://').format(
run_name=run_name,
) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink')
object_store = S3ObjectStore(bucket=f'{s3_bucket}')
else:
object_store = None
load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=object_store,
source_path=load_path.replace(f's3://{s3_bucket}/', ''),
)

if autoresume:
load_path = None
trainer2 = get_trainer(
Expand Down Expand Up @@ -1013,10 +1002,6 @@ def test_elastic_resumption(
else:
save_folder = None
sharded_load_path = os.path.join(base_path, 'ba2')
assert not is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''),
)

sharded_trainer = get_trainer(
save_folder=save_folder,
Expand Down Expand Up @@ -1237,11 +1222,6 @@ def set_up_planner(

load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=None,
source_path=load_path,
)

trainer2 = get_trainer(
save_folder=str(save_folder),
load_path=load_path,
Expand Down

0 comments on commit a8d0e0e

Please sign in to comment.