Skip to content

Commit

Permalink
[ckpt-rewr] Save state dict API (#3372)
Browse files Browse the repository at this point in the history
  • Loading branch information
eracah authored Jun 17, 2024
1 parent 3859366 commit f1cfc64
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 101 deletions.
145 changes: 145 additions & 0 deletions composer/checkpoint/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for saving state dicts to disk."""

import logging
import os
import textwrap
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed.checkpoint as DCP
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor

from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME, _write_checkpoint_file

log = logging.getLogger(__name__)


def save_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:
"""Saves a state dict to local disk.
Args:
state_dict (Dict[str,Any]): The state dict to save.
destination_file_path (str): The path to save the state dict to. If sharded,
this should be the pth to a directory. Otherwise, it should be a path to a file.
overwrite (bool): If True, the file will be overwritten if it exists.
save_format (str): The format to save the state dict in. One of 'pt', 'hf', or 'safetensor'.
Returns:
str: The full path to the saved state dict if (sharded is false and rank 0) or if sharded is true, otherwise None.
"""
if state_dict == {}:
return None
if is_state_dict_sharded(state_dict):
path_saved = _save_sharded_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
else:
if dist.get_global_rank() == 0:
path_saved = _save_full_state_dict_to_disk(state_dict, destination_file_path, overwrite, save_format)
else:
path_saved = None

return path_saved


def _save_sharded_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt',
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if state_dict == {}:
return None

# If user specifies filename instead of directory suffixes, strip them and warn
if len(Path(destination_file_path).suffixes) > 0:
stripped_path = _strip_suffixes(destination_file_path)
warnings.warn(
textwrap.dedent(
f"""Sharded checkpoints require a directory path not a file path:
{destination_file_path} will have its extensions stripped and checkpoints will be saved in {stripped_path}
as {stripped_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}""",
),
)
destination_file_path = stripped_path

if dist.get_global_rank() == 0 and not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'Directory {destination_file_path} already exists. Set overwrite=True to overwrite it.')

log.debug(
f'Starting saving of sharded state dict to {destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}',
)

# For 2.3.0 and above you can use checkpoint_id, but this version works the best for all versions
# of torch (and makes pyright happier) that we support, so we use it for now.
if version.parse(torch.__version__) < version.parse('2.2.0'):
DCP.save_state_dict(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))
else:
DCP.save(state_dict=state_dict, storage_writer=DCP.FileSystemWriter(destination_file_path))

return destination_file_path + '/' + _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME


def _save_full_state_dict_to_disk(
state_dict: Dict[str, Any],
destination_file_path: str,
overwrite: bool = False,
save_format: str = 'pt', # or hf, safetensor
) -> Optional[str]:

if save_format != 'pt':
raise NotImplementedError(
f"Saving sharded state dict to disk in format {save_format} is not supported. Please choose from ['pt'].",
)

if not overwrite and os.path.exists(destination_file_path):
raise ValueError(f'File {destination_file_path} already exists. Set overwrite=True to overwrite it.')

if dist.get_global_rank() == 0:
_write_checkpoint_file(state_dict=state_dict, filename=destination_file_path)
return destination_file_path
return None


def is_state_dict_sharded(state_dict: Dict[str, Any]) -> bool:
"""Determines if the state dict is sharded.
Args:
state_dict (Dict[str, Any]): The state dict to check.
Returns:
bool: Whether the state dict is sharded.
"""
for value in state_dict.values():
if isinstance(value, ShardedTensor) or isinstance(value, DTensor):
return True
if isinstance(value, Dict):
is_sharded = is_state_dict_sharded(value)
if is_sharded:
return True
return False


def _strip_suffixes(path: Union[str, Path]) -> str:
path = Path(path)
for _ in path.suffixes:
path = path.with_suffix('')

return str(path)
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def _get_commit_sha() -> str:
'torch': ('https://pytorch.org/docs/stable/', None),
'torchvision': ('https://pytorch.org/vision/stable/', None),
'torchtext': ('https://pytorch.org/text/stable/', None),
'torchmetrics': ('https://torchmetrics.readthedocs.io/en/latest/', None),
'libcloud': ('https://libcloud.readthedocs.io/en/stable/', None),
'PIL': ('https://pillow.readthedocs.io/en/stable', None),
'coolname': ('https://coolname.readthedocs.io/en/latest/', None),
Expand Down
110 changes: 110 additions & 0 deletions tests/checkpoint/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import CPUOffload
from torch.optim import adam

from tests.common.models import EvenSimplerMLP, SimpleComposerMLP

__all__ = [
'init_model_and_optimizer',
'init_model',
'init_optimizer',
]


def init_model_and_optimizer(
use_composer_model: bool,
num_classes=3,
batch_size=5,
num_features=8,
take_step=True,
use_fsdp=False,
tensor_type='sharded_tensor',
device='cuda',
):
model, loss_fn = init_model(
use_composer_model,
num_classes=num_classes,
num_features=num_features,
use_fsdp=use_fsdp,
tensor_type=tensor_type,
device=device,
)

optimizer = init_optimizer(
model,
loss_fn,
use_composer_model=use_composer_model,
num_classes=num_classes,
batch_size=batch_size,
num_features=num_features,
take_step=take_step,
device=device,
)

return model, optimizer


def init_model(
use_composer_model: bool = False,
num_classes=3,
num_features=8,
use_fsdp=False,
device='cuda',
tensor_type='sharded_tensor',
sync_module_states=True,
cpu_offload=False,
):
if use_composer_model:
model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device)
loss_fn = model._loss_fn
else:
model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device)
loss_fn = torch.nn.CrossEntropyLoss()

if use_fsdp:
fsdp_kwargs: Dict[str, Any] = dict(
use_orig_params=True,
sync_module_states=sync_module_states, # To enable easy comparison between rank 0 unsharded model and full state dict
cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
device_id=torch.device('cpu') if device == 'cpu' else None,
)

if tensor_type == 'dtensor':
from torch.distributed.device_mesh import init_device_mesh
device_mesh = init_device_mesh('cuda', (2,))
fsdp_kwargs['device_mesh'] = device_mesh

model = FSDP(
model,
**fsdp_kwargs,
)

return model, loss_fn


def init_optimizer(
model,
loss_fn,
use_composer_model: bool = False,
num_classes=3,
batch_size=5,
num_features=8,
take_step=True,
device='cuda',
):
inputs = torch.randn(batch_size, num_features, device=device)
targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long)
batch = (inputs, targets) if use_composer_model else inputs
optimizer = adam.Adam(model.parameters())
outputs = model(batch)
loss = loss_fn(outputs, targets)
loss.backward()
if take_step:
optimizer.step()
return optimizer
79 changes: 79 additions & 0 deletions tests/checkpoint/test_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import time
import uuid
from copy import deepcopy
from pathlib import Path

import pytest
import torch
import torch.distributed.checkpoint as DCP
from packaging import version

from composer.checkpoint.save import save_state_dict_to_disk
from composer.checkpoint.state_dict import get_model_state_dict
from composer.utils import dist
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
from tests.checkpoint.helpers import init_model
from tests.common.compare import deep_compare
from tests.common.markers import world_size


@world_size(1, 2)
@pytest.mark.gpu
@pytest.mark.parametrize('sharded_model', [False, True])
def test_save_full_state_dict_to_disk(world_size: int, tmp_path: str, sharded_model: bool):
if world_size == 1 and sharded_model:
pytest.skip("Can't have a sharded model for world_size = 1")
destination_file_path = os.path.join(tmp_path, 'test.pt')
use_fsdp = sharded_model
model, _ = init_model(use_fsdp=use_fsdp, device='cuda', sync_module_states=True)

state_dict = get_model_state_dict(model, sharded_state_dict=False)
path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path)
time.sleep(1)
if dist.get_global_rank() == 0:
assert path_saved is not None
assert path_saved == destination_file_path
assert os.path.exists(destination_file_path), f'{destination_file_path} does not exist'
loaded_state_dict = torch.load(path_saved, map_location='cuda')
deep_compare(state_dict, loaded_state_dict)
else:
assert path_saved is None


@world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize(
'tensor_type',
[
'sharded_tensor',
pytest.param(
'dtensor',
marks=pytest.mark.skipif(
version.parse(torch.__version__) < version.parse('2.2.0'),
reason='Requires torch>=2.2.0 for dtensor',
),
),
],
)
def test_save_sharded_state_dict_to_disk(world_size: int, tmp_path: str, tensor_type: str):

destination_file_path = os.path.join(tmp_path, str(uuid.uuid4())[:8])
# Sync the path across all ranks
destination_file_path = dist.all_gather_object(destination_file_path)[0]
model, _ = init_model(use_fsdp=True, device='cuda', tensor_type=tensor_type)

state_dict = get_model_state_dict(model, sharded_state_dict=True)
loaded_in_state_dict = deepcopy(state_dict)
path_saved = save_state_dict_to_disk(state_dict, destination_file_path=destination_file_path, overwrite=True)
assert path_saved == f'{destination_file_path}/{_TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME}'
assert path_saved is not None
load_path = str(Path(path_saved).parent)
if version.parse(torch.__version__) < version.parse('2.2.0'):
DCP.load_state_dict(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
else:
DCP.load(state_dict=loaded_in_state_dict, storage_reader=DCP.FileSystemReader(load_path))
deep_compare(state_dict, loaded_in_state_dict)
Loading

0 comments on commit f1cfc64

Please sign in to comment.