
Save checkpoint to disk for API with new save layout #3399

Merged
merged 70 commits on Jun 21, 2024
Changes from 64 commits

Commits (70)
b7915b9
add stubs
eracah Jun 4, 2024
5700fd4
save progress
eracah Jun 5, 2024
cfa11d4
Merge branch 'dev' of https://github.com/mosaicml/composer into sv-sd
eracah Jun 6, 2024
eabb9a3
add full state dict saving and testing
eracah Jun 6, 2024
37cfed3
Merge branch 'dev' into sv-sd
eracah Jun 6, 2024
0c503dc
Merge branch 'dev' of https://github.com/mosaicml/composer into sv-sd
eracah Jun 6, 2024
13d7e0c
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 6, 2024
316504e
add stubs
eracah Jun 6, 2024
00a7ce2
implement sharded save and get tests to pass
eracah Jun 6, 2024
0fd0d53
Merge branch 'dev' into sv-sd
eracah Jun 6, 2024
3a6c185
add cpu sharded test
eracah Jun 6, 2024
9b79b30
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 6, 2024
56f9339
pre-commit
eracah Jun 7, 2024
03dbb6a
remove comment
eracah Jun 7, 2024
7efdc74
Merge branch 'dev' into sv-sd
eracah Jun 7, 2024
e437fa7
remove __init__
eracah Jun 7, 2024
b518cde
rm init
eracah Jun 7, 2024
e9bb7ef
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 7, 2024
5156572
fix
eracah Jun 7, 2024
4f516dc
remove torchmetrics
eracah Jun 7, 2024
66c84b8
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 7, 2024
8c076b9
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dba8cdc
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dc35df0
Update composer/checkpoint/save.py
eracah Jun 10, 2024
dc5cf9f
Update composer/checkpoint/save.py
eracah Jun 10, 2024
6c67f93
Merge branch 'dev' into sv-sd
eracah Jun 10, 2024
892536d
fix docstring
eracah Jun 10, 2024
fcd1789
remove time.sleep
eracah Jun 10, 2024
1b47500
fix cpu tests
eracah Jun 11, 2024
46e4bec
pre-commit
eracah Jun 11, 2024
c4a97ce
fix cpu test
eracah Jun 11, 2024
05b9903
remove cpu tests :(
eracah Jun 11, 2024
44b123f
Merge branch 'dev' into sv-sd
eracah Jun 11, 2024
3874e58
pre-commit
eracah Jun 11, 2024
037548a
pc
eracah Jun 11, 2024
8324589
add all check
eracah Jun 12, 2024
ae3911d
pre-commit
eracah Jun 12, 2024
3e6e60b
add world_size = 1 test
eracah Jun 12, 2024
c865023
Merge branch 'sv-sd' of https://github.com/eracah/evan-composer into …
eracah Jun 12, 2024
bd885a6
pre-commit
eracah Jun 12, 2024
064d0ba
pc
eracah Jun 12, 2024
38e1f12
first commit
eracah Jun 13, 2024
d44b85c
c
eracah Jun 13, 2024
a3db90a
add save_checkpoint
eracah Jun 14, 2024
aee70ca
Merge branch 'dev' into save-layout
eracah Jun 17, 2024
e954b9c
ckpt
eracah Jun 17, 2024
67a4449
add metadata saving test
eracah Jun 18, 2024
b88caf5
add final tets
eracah Jun 18, 2024
3e3ce20
Merge branch 'save-layout' of https://github.com/eracah/evan-composer…
eracah Jun 18, 2024
2641031
pre-commit
eracah Jun 18, 2024
6633d53
Merge branch 'dev' of https://github.com/mosaicml/composer into save-…
eracah Jun 18, 2024
b7618df
Merge branch 'dev' into save-layout
eracah Jun 18, 2024
ed7da19
pre-commit
eracah Jun 18, 2024
5e4752a
Merge branch 'save-layout' of https://github.com/eracah/evan-composer…
eracah Jun 18, 2024
eb898fc
pre-commit
eracah Jun 18, 2024
b06af2b
fix
eracah Jun 18, 2024
9609478
pc
eracah Jun 18, 2024
945b18f
Merge branch 'save-layout' of https://github.com/eracah/evan-composer…
eracah Jun 18, 2024
010539c
docstrign
eracah Jun 18, 2024
9dde2cb
fix test
eracah Jun 18, 2024
0978d10
Merge branch 'save-layout' of https://github.com/eracah/evan-composer…
eracah Jun 18, 2024
fdbd1c0
refactor state dict test
eracah Jun 19, 2024
9122db1
pre-commit
eracah Jun 19, 2024
b549b3c
pc
eracah Jun 19, 2024
d33f5b8
Merge branch 'dev' into save-layout
eracah Jun 20, 2024
55a344a
Update composer/checkpoint/save.py
eracah Jun 20, 2024
4369e5a
fix test
eracah Jun 21, 2024
e6b536d
Merge branch 'save-layout' of https://github.com/eracah/evan-composer…
eracah Jun 21, 2024
0dc61f1
s
eracah Jun 21, 2024
49902fe
s
eracah Jun 21, 2024
3 changes: 1 addition & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -30,15 +30,14 @@
is_model_deepspeed,
partial_format,
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
from composer.utils.compression import get_compressor, is_compressed_pt
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

__all__ = ['CheckpointSaver']

_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


class CheckpointSaver(Callback): # noqa: D101
__doc__ = f"""Callback to save checkpoints.
286 changes: 283 additions & 3 deletions composer/checkpoint/save.py
@@ -3,19 +3,293 @@

"""Useful functions for saving state dicts to disk."""

import json
import logging
import os
import pickle
import textwrap
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union

import torch
import torch.distributed.checkpoint as DCP
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor

from composer.checkpoint.state_dict import (
get_metadata_state_dict,
get_model_state_dict,
get_optim_state_dict,
get_resumption_state_dict,
)
from composer.core import State, Time
from composer.devices import Device
from composer.models import ComposerModel
from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file
from composer.utils.file_helpers import format_name_with_dist_and_time

log = logging.getLogger(__name__)

MODEL_CHECKPOINT_DIRECTORY_NAME = 'model'
MONOLITHIC_MODEL_CHECKPOINT_FILENAME = 'model.pt'
OPTIM_CHECKPOINT_DIRECTORY_NAME = 'optim'
OPTIM_MONO_CHECKPOINT_FILENAME = 'optim.pt'
METADATA_CHECKPOINT_FILENAME = 'composer_metadata.json'
RESUMPTION_CHECKPOINT_FILENAME = 'resumption.pkl'


@dataclass
class CheckpointSaveOptions:
"""Options for saving a checkpoint to disk.

Args:
destination_dir (str): The directory to save the checkpoint to.
save_frequency (Union[str, int, Time]): The frequency to save the checkpoint.
If '1ep', the checkpoint will be saved after each epoch.
If '1ba', the checkpoint will be saved after each batch.
If an int, the checkpoint will be saved after that many epochs.
dir_prefix (str): The prefix to use for the directory name. Can include {epoch} and {batch}.
overwrite (bool): Whether to overwrite the checkpoint if it already exists.
save_model (bool): Whether to save the model.
save_optimizer (bool): Whether to save the optimizer.
save_resumption_state (bool): Whether to save the resumption state.
num_checkpoints_to_keep (int): The number of checkpoints to keep.
If -1, all checkpoints will be kept.
save_format (str): The format to save the model in. 'pt', the standard PyTorch serialization format, is the only option for now.
sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
"""
destination_dir: str
save_frequency: Union[str, int, Time] = '1ep'
dir_prefix: str = 'ep{epoch}-ba{batch}'
overwrite: bool = False
save_model: bool = True
save_optimizer: bool = True
save_resumption_state: bool = True
num_checkpoints_to_keep: int = -1
save_format: str = 'pt'
sharded_checkpoint: bool = False
precision: str = 'bf16'
# High level objects to save or not save
# e.g. 'model', 'optim', 'schedulers', 'rng' etc.
include_keys: Optional[Union[str, Sequence[str]]] = None
ignore_keys: Optional[Union[str, Sequence[str]]] = None


def save_checkpoint_to_disk(
state: State,
options: Optional[Union[CheckpointSaveOptions, Dict]] = None,
destination_dir: Optional[str] = None,
):
"""Saves a checkpoint to disk.

Args:
state (State): The state to save.
options (Optional[Union[CheckpointSaveOptions, Dict]]): The options for saving the checkpoint.
If None, destination_dir must be provided.
destination_dir (Optional[str]): The directory to save the checkpoint to.
If options is provided, this value overrides options.destination_dir.
"""
if options is None:
if destination_dir is None:
raise ValueError('destination_dir must be provided if options is None')
options = CheckpointSaveOptions(destination_dir=destination_dir)
else:
if isinstance(options, Dict):
options = CheckpointSaveOptions(**options)
if destination_dir is not None:
options.destination_dir = destination_dir
save_path = os.path.join(options.destination_dir, options.dir_prefix)
save_path = format_name_with_dist_and_time(save_path, state.run_name, state.timestamp)
os.makedirs(save_path, exist_ok=True)
if options.save_model:
save_model_to_disk(
state.model,
save_path,
options.sharded_checkpoint,
options.precision,
options.include_keys,
options.ignore_keys,
options.overwrite,
options.save_format,
)
if options.save_optimizer:
optimizer = state.optimizers[0]
save_optim_to_disk(
state.model,
optimizer,
save_path,
options.sharded_checkpoint,
options.precision,
options.overwrite,
options.save_format,
)
if options.save_resumption_state:
save_resumption_state_to_disk(state, save_path)

save_composer_metadata_to_disk(
save_path,
state.model,
options.sharded_checkpoint,
options.precision,
state.device,
state.device_train_microbatch_size,
)


def save_model_to_disk(
model: Union[ComposerModel, torch.nn.Module],
destination_dir: str,
sharded_checkpoint: bool = False,
precision: str = 'fp32',
include_keys: Optional[Union[str, Sequence[str]]] = None,
ignore_keys: Optional[Union[str, Sequence[str]]] = None,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:
"""Saves a model to disk.

Args:
model (Union[ComposerModel, torch.nn.Module]): The model to save.
destination_dir (str): The directory to save the model to.
Model will be saved as destination_dir/model/model.pt if sharded_checkpoint is False,
otherwise all shards will be saved as destination_dir/model/__<rank>_0.distcp.
sharded_checkpoint (bool): Whether to save the model as a sharded checkpoint.
precision (str): The precision to save the model in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
include_keys (Optional[Union[str, Sequence[str]]]): Keys to include in the saved model.
ignore_keys (Optional[Union[str, Sequence[str]]]): Keys to ignore in the saved model.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the model in. One of 'pt', 'hf', or 'safetensor'; only 'pt' is currently implemented.

Returns:
str: The full path to the saved model.
"""
if save_format != 'pt':
raise NotImplementedError(
f"Saving checkpoint in format {save_format} is not supported. Please choose from ['pt'].",
)
model_state_dict = get_model_state_dict(
model,
sharded_checkpoint,
precision,
include_keys,
ignore_keys,
)

destination_file_path = (
os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else
os.path.join(destination_dir, MODEL_CHECKPOINT_DIRECTORY_NAME, MONOLITHIC_MODEL_CHECKPOINT_FILENAME)
)
saved_path = save_state_dict_to_disk(
state_dict=model_state_dict,
destination_file_path=destination_file_path,
overwrite=overwrite,
save_format=save_format,
)
return saved_path


def save_optim_to_disk(
model: Union[ComposerModel, torch.nn.Module],
optimizer: torch.optim.Optimizer,
destination_dir: str,
sharded_checkpoint: bool = False,
precision: str = 'fp32',
overwrite: bool = False,
save_format: str = 'pt',
) -> Optional[str]:
"""Saves an optimizer to disk.

Args:
model (Union[ComposerModel, torch.nn.Module]): The model to save.
optimizer (torch.optim.Optimizer): The optimizer to save.
destination_dir (str): The directory to save the optimizer to.
Optimizer will be saved as destination_dir/optim/optim.pt if sharded_checkpoint is False,
otherwise all shards will be saved as destination_dir/optim/__<rank>_0.distcp.
sharded_checkpoint (bool): Whether to save the optimizer as a sharded checkpoint.
precision (str): The precision to save the optimizer in. One of 'bf16', 'fp32', 'fp16', 'fp64'.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the optimizer in. Currently only 'pt' is supported.
"""
optim_state_dict = get_optim_state_dict(
model,
optimizer,
sharded_state_dict=sharded_checkpoint,
precision=precision,
)
destination_file_path = os.path.join(destination_dir,
OPTIM_CHECKPOINT_DIRECTORY_NAME) if sharded_checkpoint else os.path.join(
destination_dir,
OPTIM_CHECKPOINT_DIRECTORY_NAME,
OPTIM_MONO_CHECKPOINT_FILENAME,
)
saved_path = save_state_dict_to_disk(
state_dict=optim_state_dict,
destination_file_path=destination_file_path,
overwrite=overwrite,
save_format=save_format,
)

return saved_path


def save_composer_metadata_to_disk(
destination_dir: str,
model: Optional[Union[ComposerModel, torch.nn.Module]] = None,
sharded_state_dict: Optional[bool] = None,
precision: Optional[Union[str, torch.dtype]] = None,
device: Optional[Device] = None,
device_train_microbatch_size: Optional[Union[int, float]] = None,
):
"""Saves metadata about the model to disk.

Args:
destination_dir (str): The directory to save the metadata to.
model (Optional[Union[ComposerModel, torch.nn.Module]]): The model to save metadata about.
sharded_state_dict (Optional[bool]): Whether the model is sharded.
precision (Optional[Union[str, torch.dtype]]): The precision of the model.
device (Optional[Device]): The device the model is on.
device_train_microbatch_size (Optional[Union[int, float]]): The device train microbatch size.
"""
md_dict = get_metadata_state_dict(
model,
sharded_state_dict,
precision,
device,
device_train_microbatch_size,
)
os.makedirs(destination_dir, exist_ok=True)
destination_file_path = os.path.join(destination_dir, METADATA_CHECKPOINT_FILENAME)

if dist.get_global_rank() == 0:
with open(destination_file_path, 'w') as f:
json.dump(md_dict, f, indent=4)
return destination_file_path


def save_resumption_state_to_disk(
state: State,
destination_dir: str,
):
"""Saves the resumption state to disk.

Args:
state (State): The state to save.
destination_dir (str): The directory to save the resumption state to.
"""
resumption_state_dict = get_resumption_state_dict(state)
destination_file_path = os.path.join(destination_dir, RESUMPTION_CHECKPOINT_FILENAME)
with open(destination_file_path, 'wb') as f:
pickle.dump(resumption_state_dict, f)
return destination_file_path


from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file

@@ -80,6 +354,8 @@ def _save_sharded_state_dict_to_disk(
)
destination_file_path = stripped_path

# Wait for all ranks to get here before checking if the directory exists.
dist.barrier()
if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')

@@ -94,6 +370,9 @@ def _save_sharded_state_dict_to_disk(
else:
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))

log.debug(
f'Finished saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
)
return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME


@@ -106,13 +385,14 @@ def _save_full_state_dict_to_disk(

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
f"Saving full state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')

if dist.get_global_rank() == 0:
os.makedirs(os.path.dirname(destination_file_path), exist_ok=True)
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
return destination_file_path
return None
@@ -130,7 +410,7 @@ def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
for value in state_dict.values():
if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
return True
if isinstance(value, Dict):
elif isinstance(value, Dict):
is_sharded = is_state_dict_sharded(value)
if is_sharded:
return True
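
For reference, a minimal usage sketch of the new high-level API added in this file. Only the function signatures, option fields, and filename constants come from the diff above; the `state` object and the epoch/batch values in the layout comment are hypothetical.

from composer.checkpoint.save import CheckpointSaveOptions, save_checkpoint_to_disk

# `state` is assumed to be an existing composer.core.State (e.g. trainer.state
# from an already-constructed Trainer); it is not built in this sketch.
options = CheckpointSaveOptions(
    destination_dir='checkpoints',
    save_frequency='1ep',              # '1ba', an int (epochs), or a Time object also work
    dir_prefix='ep{epoch}-ba{batch}',  # formatted with the run name and current timestamp
    sharded_checkpoint=False,          # monolithic save, written by rank 0
    precision='bf16',
)
save_checkpoint_to_disk(state=state, options=options)

# Expected on-disk layout for a monolithic save at, say, epoch 1 / batch 100:
# checkpoints/ep1-ba100/
#     model/model.pt              (MODEL_CHECKPOINT_DIRECTORY_NAME / MONOLITHIC_MODEL_CHECKPOINT_FILENAME)
#     optim/optim.pt              (OPTIM_CHECKPOINT_DIRECTORY_NAME / OPTIM_MONO_CHECKPOINT_FILENAME)
#     composer_metadata.json      (METADATA_CHECKPOINT_FILENAME)
#     resumption.pkl              (RESUMPTION_CHECKPOINT_FILENAME)
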
2 changes: 1 addition & 1 deletion composer/checkpoint/state_dict.py
@@ -380,7 +380,7 @@ def get_metadata_state_dict(
sharded_state_dict: Optional[bool] = None,
precision: Optional[Union[str, torch.dtype]] = None,
device: Optional[Device] = None,
device_train_microbatch_size: Optional[int] = None,
device_train_microbatch_size: Optional[Union[int, float]] = None,
) -> dict[str, Any]:
"""Generate the metadata and integrations for a training run.
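
The only functional change to get_metadata_state_dict here is widening device_train_microbatch_size to also accept floats. A hedged sketch of a direct call reflecting that (my_model is a hypothetical ComposerModel instance; all keyword arguments are optional):

from composer.checkpoint.state_dict import get_metadata_state_dict

# Sketch only: a fractional microbatch size now passes the type hint.
metadata = get_metadata_state_dict(
    my_model,
    sharded_state_dict=False,
    precision='bf16',
    device_train_microbatch_size=0.5,  # hypothetical fractional value
)
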
1 change: 1 addition & 0 deletions composer/utils/checkpoint.py
@@ -53,6 +53,7 @@
_COMPOSER_STATES_FILENAME = 'composer_states.pt'
_DEEPSPEED_TAG = 'deepspeed' # always tag with the same, deterministic name. We'll rename the tarball to the appropriate name.
_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp'
_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


def _get_checkpoint_validation_function(
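
For the sharded path, these filename constants imply a per-rank layout roughly like the following (a world size of 2 is assumed purely for illustration; the .metadata file is written by torch.distributed.checkpoint alongside the shards):

# destination_dir/ep1-ba100/
#     model/
#         __0_0.distcp            rank 0 shard (_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME)
#         __1_0.distcp            rank 1 shard
#         .metadata               (_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME)
#     optim/
#         __0_0.distcp
#         __1_0.distcp
#         .metadata
#     composer_metadata.json
#     resumption.pkl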