Commit d0712dc

Merge branch 'dev' of github.com-mvpatel2000:mosaicml/composer into dev
mvpatel2000 committed Jun 5, 2024
2 parents de46308 + fe6140e commit d0712dc
Showing 4 changed files with 257 additions and 14 deletions.
8 changes: 7 additions & 1 deletion composer/checkpoint/__init__.py
@@ -3,10 +3,16 @@

"""Module for checkpointing API."""

-from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict
+from composer.checkpoint.state_dict import (
+    get_metadata_state_dict,
+    get_model_state_dict,
+    get_optim_state_dict,
+    get_resumption_state_dict,
+)

__all__ = [
    'get_model_state_dict',
    'get_optim_state_dict',
    'get_metadata_state_dict',
+    'get_resumption_state_dict',
]
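
With the reworked import block, all four public getters come straight from composer.checkpoint. A minimal usage sketch (not part of the commit; it assumes a composer.core.State named `state` built elsewhere, and the argument lists shown are abbreviated assumptions — see state_dict.py below for the full signatures):

    from composer.checkpoint import (
        get_metadata_state_dict,
        get_model_state_dict,
        get_optim_state_dict,
        get_resumption_state_dict,
    )

    # Each getter returns a plain dict that can be serialized with torch.save.
    model_sd = get_model_state_dict(state.model)
    optim_sd = get_optim_state_dict(state.model, state.optimizers[0])
    metadata_sd = get_metadata_state_dict(state.model)
    resumption_sd = get_resumption_state_dict(state)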
135 changes: 132 additions & 3 deletions composer/checkpoint/state_dict.py
@@ -6,21 +6,32 @@
import fnmatch
import logging
import sys
-from typing import Any, Dict, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Union

+from torch.utils.data import DataLoader, Dataset
+
+from composer.core.data_spec import DataSpec
+
+if TYPE_CHECKING:
+    from composer.core.evaluator import Evaluator

import torch
from packaging import version
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader

from composer.core.evaluator import Evaluator
from composer.core.state import State
from composer.core.time import Timestamp
from composer.devices import Device
from composer.models import ComposerModel, HuggingFaceModel
-from composer.utils import STR_TO_DTYPE, dist, get_composer_env_dict
+from composer.utils import STR_TO_DTYPE, dist, get_composer_env_dict, reproducibility

log = logging.getLogger(__name__)

-__all__ = ['get_model_state_dict', 'get_optim_state_dict']
+__all__ = ['get_model_state_dict', 'get_optim_state_dict', 'get_metadata_state_dict', 'get_resumption_state_dict']


def get_model_state_dict(
@@ -158,6 +169,93 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st
    return model_state_dict


def get_resumption_state_dict(state: State) -> Dict[str, Any]:
    """Generate the state dict for any objects needed for resumption.

    This includes:
        * timestamp
        * scheduler
        * dataset_state
        * scaler
        * rank_zero_seed
        * run_name
        * rng
        * callbacks
        * algorithms

    Returns:
        The state dict containing the objects needed for resumption.
    """
    resumption_state_dict = {}
    resumption_state_dict['dataset_state'] = get_dataset_state_dict(
        state.train_dataloader,
        state.timestamp,
    )
    resumption_state_dict['timestamp'] = state.timestamp.state_dict()

    scheduler_state_dict = _make_state_dict_for_list_of_objects(state.schedulers)
    if scheduler_state_dict != {}:
        resumption_state_dict['schedulers'] = scheduler_state_dict

    # Use a list of tuples to account for duplicates (e.g. two callbacks of the same class).
    callbacks_state_dict = _make_state_dict_for_list_of_objects(state.callbacks, use_list_of_tuples=True)
    if callbacks_state_dict != {}:
        resumption_state_dict['callbacks'] = callbacks_state_dict

    # Use a list of tuples to preserve order.
    algorithms_state_dict = _make_state_dict_for_list_of_objects(state.algorithms, use_list_of_tuples=True)
    if algorithms_state_dict != {}:
        resumption_state_dict['algorithms'] = algorithms_state_dict

    if state.scaler is not None:
        scaler_sd = _make_state_dict_for_list_of_objects(state.scaler)
        if scaler_sd != {}:
            resumption_state_dict['scaler'] = state.scaler.state_dict()

    resumption_state_dict['rank_zero_seed'] = state.rank_zero_seed
    resumption_state_dict['run_name'] = state.run_name
    resumption_state_dict['rng'] = reproducibility.get_rng_state()

    return resumption_state_dict


def _make_state_dict_for_list_of_objects(
    objects: Union[Sequence[Any], Any],
    use_list_of_tuples: bool = False,
) -> Union[Dict[str, Any], List]:
    object_list = []
    object_dict = {}
    if not isinstance(objects, Sequence):
        objects = [objects]
    for obj in objects:
        # Skip objects that are not stateful, or whose state is empty.
        if not hasattr(obj, 'state_dict') or obj.state_dict() == {}:
            continue
        if use_list_of_tuples:
            object_list.append((type(obj).__qualname__, obj.state_dict()))
        else:
            object_dict[type(obj).__qualname__] = obj.state_dict()
    if use_list_of_tuples:
        return object_list
    else:
        return object_dict
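
# The list-of-tuples form matters when two stateful objects share a
# __qualname__: the dict form would keep only the last one, while the tuple
# form keeps both and preserves order (exactly what test_get_resumption_state_dict
# below asserts for two SpeedMonitor callbacks). A quick sketch:
#
#   from composer.callbacks import SpeedMonitor
#   _make_state_dict_for_list_of_objects([SpeedMonitor(), SpeedMonitor()], use_list_of_tuples=True)
#   # -> [('SpeedMonitor', {'total_eval_wct': 0.0}), ('SpeedMonitor', {'total_eval_wct': 0.0})]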


def get_dataset_state_dict(
    train_dataloader: Optional[Union[DataLoader, Iterable]],
    timestamp: Timestamp,
) -> Dict[str, Any]:
    """Collect the state dict of the train dataset, if it supports checkpointing.

    Returns:
        Dict[str, Any]: The dataset state dict, keyed by 'train'.
    """
    dataset_state_dict = {
        'train': None,
    }
    dataset = _dataset_of(train_dataloader)
    if hasattr(dataset, 'state_dict'):
        num_samples = int(timestamp.sample_in_epoch.value)
        dataset_state_dict['train'] = dataset.state_dict(num_samples, True)  # pyright: ignore

    return dataset_state_dict
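
# A dataset is captured above only if it exposes a
# `state_dict(num_samples, from_beginning)` method, which is the interface the
# call above assumes (streaming-style datasets provide a compatible method).
# The class below is a hypothetical sketch of that interface:
#
#   class ResumableDataset(Dataset):
#       def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
#           return {'num_samples': num_samples, 'from_beginning': from_beginning}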


def _get_optim_state_dict_with_fsdp_context_manager(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
@@ -356,3 +454,34 @@ def get_metadata_state_dict(
        metadata_state_dict['precision'] = 'fp32'

    return metadata_state_dict


def _dataset_of(dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
    """Get the dataset contained by the given dataloader-like object.

    Args:
        dataloader (Evaluator | DataSpec | DataLoader | Iterable, optional): The dataloader, wrapped dataloader, or
            generic python iterable to get the dataset of, if applicable.

    Returns:
        Dataset: Its dataset, if there is one.
    """
    from composer.core.evaluator import Evaluator

    # If it's None, no dataset for you.
    if dataloader is None:
        return None

    # An Evaluator is a dataloader wrapped with metrics. Unwrap its dataloader.
    if isinstance(dataloader, Evaluator):
        dataloader = dataloader.dataloader

    # A DataSpec is a dataloader wrapped with an on-device transform. Unwrap its dataloader.
    if isinstance(dataloader, DataSpec):
        dataloader = dataloader.dataloader

    # If what we now have is an actual DataLoader, return its dataset. If not, return None.
    if isinstance(dataloader, DataLoader):
        return dataloader.dataset
    else:
        return None
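
End to end, the new resumption path added in this file can be exercised with a short sketch like the following (not part of the commit; it assumes `state` is a fully constructed composer.core.State and `save_path` is a hypothetical destination):

    import torch

    from composer.checkpoint import get_resumption_state_dict

    rsd = get_resumption_state_dict(state)
    # Always present: 'timestamp', 'dataset_state', 'rank_zero_seed',
    # 'run_name', and 'rng'. Present only when non-empty: 'schedulers',
    # 'callbacks', 'algorithms', and 'scaler'.
    torch.save(rsd, save_path)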
126 changes: 118 additions & 8 deletions tests/checkpoint/test_state_dict.py
@@ -1,17 +1,29 @@
# Copyright 2024 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

+import datetime
from typing import Any, Dict
+from unittest.mock import MagicMock

import pytest
import torch
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim import adam

-from composer.checkpoint import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict
-from composer.devices import DeviceGPU
-from composer.utils import dist
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DataLoader

+from composer.algorithms import SWA
+from composer.callbacks import SpeedMonitor
+from composer.checkpoint import (
+    get_metadata_state_dict,
+    get_model_state_dict,
+    get_optim_state_dict,
+    get_resumption_state_dict,
+)
+from composer.core import State
+from composer.devices import DeviceCPU, DeviceGPU
+from composer.utils import dist, reproducibility
from tests.common.compare import deep_compare
from tests.common.markers import world_size
from tests.common.models import EvenSimplerMLP, SimpleComposerMLP, configure_tiny_gpt2_hf_model
@@ -242,6 +254,7 @@ def _init_model_and_optimizer(
    take_step=True,
    use_fsdp=False,
    tensor_type='sharded_tensor',
+    device='cuda',
):
    model, loss_fn = _init_model(
@@ -250,6 +263,7 @@ def _init_model_and_optimizer(
        use_composer_model,
        num_features=num_features,
        use_fsdp=use_fsdp,
        tensor_type=tensor_type,
+        device=device,
    )

    optimizer = _init_optimizer(
@@ -260,6 +274,7 @@ def _init_model_and_optimizer(
        batch_size=batch_size,
        num_features=num_features,
        take_step=take_step,
+        device=device,
    )

    return model, optimizer
@@ -271,13 +286,14 @@ def _init_model(
    batch_size=5,
    num_features=8,
    use_fsdp=False,
+    device='cuda',
    tensor_type='sharded_tensor',
):
    if use_composer_model:
-        model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda')
+        model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device=device)
        loss_fn = model._loss_fn
    else:
-        model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda')
+        model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device=device)
        loss_fn = torch.nn.CrossEntropyLoss()

    if use_fsdp:
@@ -307,9 +323,10 @@ def _init_optimizer(
    batch_size=5,
    num_features=8,
    take_step=True,
+    device='cuda',
):
-    inputs = torch.randn(batch_size, num_features, device='cuda')
-    targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long)
+    inputs = torch.randn(batch_size, num_features, device=device)
+    targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device, dtype=torch.long)
    batch = (inputs, targets) if use_composer_model else inputs
    optimizer = adam.Adam(model.parameters())
    outputs = model(batch)
@@ -514,3 +531,96 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz

    assert 'dist_backend' in metadata_sd
    assert metadata_sd['dist_backend'] == 'nccl'


@pytest.mark.filterwarnings('ignore:SWA has')
def test_get_resumption_state_dict():
    model, optimizer = _init_model_and_optimizer(use_composer_model=True, take_step=True, device='cpu')

    rank_zero_seed = 10
    run_name = 'test_run'
    device = DeviceCPU()
    test_dataset_sd = {'foo': 0}
    dataloader = MagicMock(spec=DataLoader)
    dataloader.dataset = MagicMock()
    dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd)
    swa = SWA()
    state = State(
        model=model,
        rank_zero_seed=rank_zero_seed,
        run_name=run_name,
        device=device,
        train_dataloader=dataloader,
        algorithms=[swa],
        callbacks=[SpeedMonitor(), SpeedMonitor()],
    )
    state.schedulers = StepLR(optimizer=optimizer, step_size=2)
    rsd = get_resumption_state_dict(state)

    assert rsd['rank_zero_seed'] == rank_zero_seed
    assert rsd['run_name'] == run_name
    assert 'timestamp' in rsd
    assert rsd['timestamp'] == {
        'iteration': 0,
        'epoch': 0,
        'batch': 0,
        'sample': 0,
        'token': 0,
        'epoch_in_iteration': 0,
        'batch_in_epoch': 0,
        'sample_in_epoch': 0,
        'token_in_epoch': 0,
        'total_wct': datetime.timedelta(0),
        'iteration_wct': datetime.timedelta(0),
        'epoch_wct': datetime.timedelta(0),
        'batch_wct': datetime.timedelta(0),
    }
    assert rsd['dataset_state'] == {'train': test_dataset_sd}
    dict(rsd['algorithms'])['SWA'].pop('repr')
    assert rsd['algorithms'] == [
        (
            'SWA',
            {
                'swa_model': None,
                'swa_completed': False,
                'swa_started': False,
                'swa_scheduler': None,
                'step_counter': 0,
            },
        ),
    ]
    assert rsd['callbacks'] == [('SpeedMonitor', {'total_eval_wct': 0.0}), ('SpeedMonitor', {'total_eval_wct': 0.0})]


@pytest.mark.gpu
def test_get_resumption_state_dict_gpu():
    if version.parse(torch.__version__) >= version.parse('2.3.0'):
        from torch.amp.grad_scaler import GradScaler
    else:
        from torch.cuda.amp.grad_scaler import GradScaler

    model, _ = _init_model_and_optimizer(use_composer_model=True, take_step=False, device='cuda')

    rank_zero_seed = 10
    run_name = 'test_run'
    device = DeviceCPU()
    test_dataset_sd = {'test': 0}
    dataloader = MagicMock()
    dataloader.dataset = MagicMock()
    dataloader.dataset.state_dict = MagicMock(return_value=test_dataset_sd)
    state = State(
        model=model,
        rank_zero_seed=rank_zero_seed,
        run_name=run_name,
        device=device,
        scaler=GradScaler(),
    )
    rsd = get_resumption_state_dict(state)
    assert 'scaler' in rsd
    assert set(
        rsd['scaler'].keys(),
    ) == {'scale', 'growth_factor', 'backoff_factor', 'growth_interval', '_growth_tracker'}

    assert 'rng' in rsd
    deep_compare(rsd['rng'], reproducibility.get_rng_state())
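
To exercise just the new resumption tests locally, something like the following should work in a standard Composer dev checkout (the GPU variant additionally requires CUDA):

    pytest tests/checkpoint/test_state_dict.py -k resumption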
2 changes: 0 additions & 2 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -36,7 +36,6 @@
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
from tests.common import RandomClassificationDataset, deep_compare
-from tests.common.compare import deep_compare
from tests.common.markers import world_size
from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent

@@ -188,7 +187,6 @@ def _compare_optims_between_state_dicts(state_dict1, state_dict2):
            state_dict1_moment = state_dict1_moment.to_local()
        if isinstance(state_dict2_moment, DTensor):
            state_dict2_moment = state_dict2_moment.to_local()
-        print(param_name, state_dict1_moment, state_dict2_moment)
        torch.testing.assert_close(state_dict1_moment, state_dict2_moment)


