Skip to content

Commit

Permalink
Remove Legacy Checkpointing (#3631)
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpatel2000 authored Oct 1, 2024
1 parent e5e2f74 commit 2963293
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 66 deletions.
52 changes: 9 additions & 43 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,24 +413,6 @@ def format(self, state: State, is_deepspeed: bool = False, keep_placeholders: bo
) + extra_suffix


def is_checkpoint_legacy_sharded(object_store: Optional[Union[LoggerDestination, ObjectStore]], source_path: str):
if source_path.endswith('.symlink') or os.path.islink(source_path):
source_path = extract_path_from_symlink(source_path, object_store=object_store)
metadata_path = str(Path(source_path) / Path('.metadata'))
log.debug(f'Checking if checkpoint is legacy sharded by checking for metadata file at {metadata_path}.')
if object_store is None:
return not os.path.exists(metadata_path)
else:
try:
_, _, metadata_path = parse_uri(metadata_path)
with tempfile.TemporaryDirectory() as temp_dir:
metadata_destination = os.path.join(str(temp_dir), '.metadata')
download_object_or_file(metadata_path, metadata_destination, object_store)
return False
except FileNotFoundError:
return True


def load_checkpoint(
path: str,
state: State,
Expand Down Expand Up @@ -533,16 +515,8 @@ def load_checkpoint(
"""
path = partial_format(path, run_name=state.run_name)
log.debug(f'Loading checkpoint from formatted path: {path}')
using_legacy_sharded = False

if state.fsdp_sharded_state_dict_enabled:
assert object_store is None or isinstance(
object_store,
ObjectStore,
), 'For loading sharded checkpoints load_object_store must be set with the class ObjectStore'
using_legacy_sharded = is_checkpoint_legacy_sharded(object_store, path)
log.info(f'Using legacy sharded checkpoint: {using_legacy_sharded}')

if state.fsdp_sharded_state_dict_enabled and not using_legacy_sharded:
rng_state_dicts = load_sharded_checkpoint(
source_path=path,
state=state,
Expand All @@ -557,26 +531,20 @@ def load_checkpoint(
)
else:
# Download the checkpoint to the node-local folder
log.debug('Loading checkpoint at %s', path)
# Each node gets one unique folder to store checkpoints that is shared amongst all local ranks in that node.
# If fsdp sharded state_dicts is enabled then EVERY rank gets a unique checkpoint folder.
needs_unique_checkpoint_folder = state.fsdp_sharded_state_dict_enabled or dist.get_local_rank() == 0
tempdir_ctx = tempfile.TemporaryDirectory() if needs_unique_checkpoint_folder else contextlib.nullcontext(None)
tempdir_ctx = tempfile.TemporaryDirectory() if dist.get_local_rank() == 0 else contextlib.nullcontext(None)
with tempdir_ctx as tempdir:
try:
# Get the path to the proper checkpoint folder corresponding to the current rank's node.
# If fsdp_sharded_state_dict_enabled then just use that rank's unique tempdir.
node_checkpoint_folder = (
tempdir if state.fsdp_sharded_state_dict_enabled else _get_local_rank_zero_path(tempdir)
)
assert node_checkpoint_folder is not None
node_checkpoint_folder = _get_local_rank_zero_path(tempdir)

composer_states_filepath, extracted_checkpoint_folder, extracted_rank_n = download_checkpoint(
path=path,
node_checkpoint_folder=node_checkpoint_folder,
object_store=object_store,
progress_bar=progress_bar,
fsdp_sharded_state_dict_enabled=state.fsdp_sharded_state_dict_enabled,
deepspeed_sharded_checkpoint=is_model_deepspeed(state.model),
)
rng_state_dicts = _restore_checkpoint(
Expand All @@ -596,6 +564,8 @@ def load_checkpoint(
# be a shared resource between nodes.
dist.barrier()
log.info('%s loaded from %s', 'Model weights' if load_weights_only else 'Trainer checkpoint', path)

# Verify all ranks resumed on same step
step_to_resume_from = state.timestamp.batch.value
max_step_to_resume_from = state.device.tensor_to_device(
torch.tensor(state.timestamp.batch.value, dtype=torch.int64),
Expand Down Expand Up @@ -802,7 +772,6 @@ def download_checkpoint(
node_checkpoint_folder: str,
object_store: Optional[Union[ObjectStore, LoggerDestination]],
progress_bar: bool,
fsdp_sharded_state_dict_enabled: bool = False,
deepspeed_sharded_checkpoint: bool = False,
) -> tuple[str, Optional[str], bool]:
"""Download the checkpoint stored at ``path``, potentially in ``object_store``, to ``node_checkpoint_folder``.
Expand All @@ -829,9 +798,7 @@ def download_checkpoint(
# and only rank zero has this file unless fsdp_sharded_state_dict_enabled then
# every rank has it's own file.
extracted_checkpoint_folder = None
composer_states_filepath = (
rank_n_checkpoint_filepath if fsdp_sharded_state_dict_enabled else rank_zero_checkpoint_filepath
)
composer_states_filepath = rank_zero_checkpoint_filepath

if is_compressed_pt(path):
original_path = path
Expand All @@ -841,9 +808,8 @@ def download_checkpoint(
with compressor.decompress(original_path) as in_file:
shutil.copyfileobj(in_file, out_file)

checkpoint_is_sharded = fsdp_sharded_state_dict_enabled or deepspeed_sharded_checkpoint
try:
if not checkpoint_is_sharded and dist.get_local_rank() == 0:
if not deepspeed_sharded_checkpoint and dist.get_local_rank() == 0:
# If the checkpoint is not sharded, then local rank 0 on each node needs to download the
# global rank 0 checkpoint
path = _format_path_with_rank_zero(path)
Expand All @@ -862,7 +828,7 @@ def download_checkpoint(
# the underlying issue is that the checkpoint file does not exist on the disk
# or could not be downloaded
raise RuntimeError(f'Checkpoint {path} does not exist')
elif checkpoint_is_sharded:
elif deepspeed_sharded_checkpoint:
# If the checkpoint is sharded, then every rank needs to download its own checkpoint
path = _format_path_with_current_rank(path)
try:
Expand Down Expand Up @@ -892,7 +858,7 @@ def download_checkpoint(

finally:
# Use busy wait to avoid timeouts on large downloads for non-sharded checkpoints
if not checkpoint_is_sharded:
if not deepspeed_sharded_checkpoint:
signal_file_path = os.path.join(
node_checkpoint_folder,
dist.get_node_signal_file_name(),
Expand Down
27 changes: 4 additions & 23 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from composer.utils import FSDPConfig, TPConfig, dist, parse_uri
from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded
from composer.utils.checkpoint import dist_cp_load
from composer.utils.file_helpers import get_file
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
Expand Down Expand Up @@ -539,7 +539,8 @@ def test_fsdp_load_old_checkpoint(
pytest.skip('Current torch version is older than torch version that checkpoint was written with.')

if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']:
rank = 0 if state_dict_type == 'full' else '{rank}'
if state_dict_type == 'sharded':
pytest.skip('Loading legacy sharded checkpoints are not supported after v0.25.0.')

load_path_dir = (
f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
Expand All @@ -549,11 +550,7 @@ def test_fsdp_load_old_checkpoint(
if ((version.parse(composer_version) > version.parse('0.15.0')) and state_dict_type != 'full'):
load_path_dir = (load_path_dir + 'ep0-ba2/')

load_path = load_path_dir + f'ba2_rank{rank}.pt'
assert is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=load_path.lstrip(f's3://{s3_bucket}/'),
)
load_path = load_path_dir + f'ba2_rank0.pt'
else:
load_path = (
f's3://{s3_bucket}/{s3_read_only_prefix}/backwards_compatibility/'
Expand Down Expand Up @@ -911,16 +908,9 @@ def test_fsdp_partitioned_state_dict_load(
load_path = 's3://' + save_folder.strip('s3://').format(
run_name=run_name,
) + ('/ba2' if not use_symlink else '/latest-rank0.pt.symlink')
object_store = S3ObjectStore(bucket=f'{s3_bucket}')
else:
object_store = None
load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=object_store,
source_path=load_path.replace(f's3://{s3_bucket}/', ''),
)

if autoresume:
load_path = None
trainer2 = get_trainer(
Expand Down Expand Up @@ -1015,10 +1005,6 @@ def test_elastic_resumption(
else:
save_folder = None
sharded_load_path = os.path.join(base_path, 'ba2')
assert not is_checkpoint_legacy_sharded(
object_store=S3ObjectStore(bucket=f'{s3_bucket}'),
source_path=sharded_load_path.replace(f's3://{s3_bucket}/', ''),
)

sharded_trainer = get_trainer(
save_folder=save_folder,
Expand Down Expand Up @@ -1239,11 +1225,6 @@ def set_up_planner(

load_path = str(save_folder.format(run_name=run_name) / pathlib.Path('ba2'))

assert not is_checkpoint_legacy_sharded(
object_store=None,
source_path=load_path,
)

trainer2 = get_trainer(
save_folder=str(save_folder),
load_path=load_path,
Expand Down

0 comments on commit 2963293

Please sign in to comment.