Run TP Checkpoint Tests #3600

Open
wants to merge 17 commits into base: main
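Adds a standalone TEST.py script that exercises tensor-parallel (TP) + FSDP checkpointing, plus icecream (ic) tracing through the checkpoint save path (checkpoint_saver, engine, state, trainer, checkpoint, dist, reproducibility) to show how far saving gets on each rank. As an illustration only (the exact command is an assumption, not part of this PR), the script can be launched on a multi-GPU node with the Composer launcher, e.g. `composer -n 4 TEST.py`; the process count of 4 is illustrative, and any world size that is a multiple of the tensor-parallel degree of 2 should work.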
218 changes: 218 additions & 0 deletions TEST.py
@@ -0,0 +1,218 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import contextlib
import dataclasses
import os
import pathlib
import textwrap
import uuid
from contextlib import nullcontext as does_not_raise
from functools import partial
from typing import Any, Callable, Optional, Sequence, Union
from unittest.mock import patch

import numpy as np
import pytest
import torch
from packaging import version
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor
from torch.utils.data import DataLoader
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy

from composer.algorithms import EMA
from composer.core import Algorithm, Event, Precision, State, Time
from composer.core.state import fsdp_get_optim_state_dict, fsdp_state_dict_type_context
from composer.models import ComposerClassifier
from composer.optim import DecoupledAdamW
from composer.trainer import Trainer
from composer.utils import FSDPConfig, TPConfig, dist, parse_uri
from composer.utils.checkpoint import dist_cp_load, is_checkpoint_legacy_sharded
from composer.utils.file_helpers import get_file
from composer.utils.object_store import S3ObjectStore
from composer.utils.reproducibility import get_rng_state
from tests.common import RandomClassificationDataset, deep_compare
from tests.common.markers import world_size
from tests.trainer.test_checkpoint import TestCheckpointResumption, _assert_checkpoints_equivalent

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from tests.trainer.test_fsdp_checkpoint import _compare_model_params_between_state_dicts, _compare_optims_between_state_dicts, _compare_metrics_between_state_dicts
from icecream import install
from icecream import ic

install()
ic.configureOutput(includeContext=True)

def test_1(use_tp: bool):
from tests.trainer.test_fsdp_checkpoint import get_trainer

    tmp_path: pathlib.Path = pathlib.Path('tmp')
autoresume: bool = False
precision: str = 'amp_bf16'
optimizer: str = 'adam'
save_weights_only: bool = False
load_weights_only: bool = False
load_monolith_rank0_only: bool = False
use_hsdp: bool = False

if use_hsdp and version.parse(torch.__version__) < version.parse('2.4.0'):
pytest.xfail('HSDP requires torch 2.4.0 or later')
if use_tp and version.parse(torch.__version__) < version.parse('2.4.0'):
pytest.skip('TP has full state dict issues before PyTorch 2.4.')
if autoresume:
run_name = 'my-cool-autoresume-run'
else:
run_name = None
save_folder = tmp_path
save_filename = 'rank{rank}.pt'

if use_hsdp:
fsdp_config = FSDPConfig(
sharding_strategy='HYBRID_SHARD',
sharded_ckpt_prefix_dir='ba{batch}',
data_parallel_shard_degree=2,
data_parallel_replicate_degree=2,
sync_module_states=True,
)
else:
fsdp_config = FSDPConfig(
sharded_ckpt_prefix_dir='ba{batch}',
sync_module_states=load_monolith_rank0_only,
load_monolith_rank0_only=load_monolith_rank0_only,
)
tp_config = None
if use_tp:
        # ColwiseParallel/RowwiseParallel are imported at module level; shard the first
        # Linear column-wise and the second row-wise across the 2-way TP group.
tp_config = {
'tensor_parallel_degree': 2,
'layer_plan': {
'module.0': ColwiseParallel(),
'module.2': RowwiseParallel(),
},
}

trainer1 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
run_name=run_name,
precision=precision,
autoresume=autoresume,
optimizer=optimizer,
fsdp_config=fsdp_config,
tp_config=tp_config,
)

if use_tp:
assert trainer1.state.tp_config is not None
assert isinstance(trainer1.state.tp_config, TPConfig)

ic('Before trainer 1 fit')
print('Before trainer 1 fit')
trainer1.fit()
print('After trainer 1 fit')
state_dict_from_trainer1 = trainer1.state.state_dict()
trainer1.close()
load_path = str(save_folder / pathlib.Path('rank{rank}.pt'))
trainer2 = get_trainer(
save_folder=str(save_folder),
save_filename=save_filename,
load_path=load_path,
run_name=run_name,
precision=precision,
autoresume=autoresume,
max_duration='4ba',
optimizer=optimizer,
fsdp_config=fsdp_config,
save_weights_only=save_weights_only,
load_weights_only=load_weights_only,
tp_config=tp_config,
)
state_dict_from_trainer2 = trainer2.state.state_dict()

if dist.get_global_rank() == 0:
_compare_model_params_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
if not load_weights_only:
_compare_optims_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
_compare_metrics_between_state_dicts(
state_dict_from_trainer1,
state_dict_from_trainer2,
)
# Continue to fit to make sure we can continue training.
trainer2.fit()
trainer2.close()


def test_2(use_tp: bool, state_dict_type: str):
from tests.trainer.test_fsdp_checkpoint import SimpleMLP

fsdp_config = FSDPConfig(sharded_ckpt_prefix_dir='ba{batch}', state_dict_type=state_dict_type)
tp_config = None
if use_tp:
tp_config = {
'tensor_parallel_degree': 2,
'layer_plan': {'module.0': ColwiseParallel(), 'module.2': RowwiseParallel()},
}
parallelism_config: dict[str, Union[FSDPConfig, dict[str, Any]]] = {'fsdp': fsdp_config}
if tp_config is not None:
parallelism_config['tp'] = tp_config

num_features: int = 4
num_classes: int = 2
model = SimpleMLP(num_features=num_features, num_classes=num_classes)
model.module.to('cpu')
dataset = RandomClassificationDataset(shape=(num_features,), num_classes=num_classes, size=128)
    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset), batch_size=8)
optim = torch.optim.Adam(params=model.parameters())

trainer1 = Trainer(
model=model,
optimizers=optim,
train_dataloader=dataloader,
parallelism_config=parallelism_config,
save_folder='tmp',
max_duration='2ba',
save_interval='2ba',
save_filename='rank{rank}.pt',
precision='amp_bf16',
progress_bar=False,
log_to_console=False,
save_latest_filename='latest-rank{rank}.pt',
)

if use_tp:
assert trainer1.state.tp_config is not None
assert isinstance(trainer1.state.tp_config, TPConfig)

ic('Before trainer 1 fit')
print('Before trainer 1 fit')
trainer1.fit()
print('After trainer 1 fit')

if __name__ == '__main__':
test = test_2
verbose = False

if not verbose:
ic.disable()
os.environ['NCCL_DEBUG'] = 'WARN'
if verbose:
os.environ['NCCL_DEBUG'] = 'INFO'

print('*'*70, '\nuse_tp=True, state_dict_type=sharded\n', '*'*70)
test(use_tp=True, state_dict_type='sharded')
print('*'*70, '\nDone\n', '*'*70)

print('*'*70, '\nuse_tp=True, state_dict_type=full\n', '*'*70)
test(use_tp=True, state_dict_type='full')
print('*'*70, '\nDone\n', '*'*70)
4 changes: 4 additions & 0 deletions composer/callbacks/checkpoint_saver.py
@@ -396,10 +396,12 @@ def fit_start(self, state: State, logger: Logger) -> None:
def batch_checkpoint(self, state: State, logger: Logger):
assert callable(self.save_interval)
if self.save_interval(state, Event.BATCH_CHECKPOINT) and self.last_checkpoint_batch != state.timestamp.batch:
ic('before _save_checkpoint')
self._save_checkpoint(
state,
logger,
)
ic('after _save_checkpoint')

def epoch_checkpoint(self, state: State, logger: Logger):
assert callable(self.save_interval)
@@ -472,12 +474,14 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Store before saving so state_dict in checkpoint has reference to latest checkpoint (itself)
self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp

ic('before checkpoint.save_checkpoint')
saved_path = checkpoint.save_checkpoint(
state=state,
filename=filename_with_placeholders,
weights_only=self.weights_only,
ignore_keys=self.ignore_keys,
)
ic('after checkpoint.save_checkpoint')
log.debug(f'Checkpoint locally saved to {saved_path}')

self.symlink_count += 1
5 changes: 5 additions & 0 deletions composer/core/engine.py
@@ -298,9 +298,11 @@ def run_event(
traces = self._run_algorithms(event)
else:
traces = self._run_algorithms(event)
ic('before _run_nonlogger_callbacks')
# Run callbacks first, so any log calls from a callback that are executed lazily
# get registered before they are flushed by the logger itself.
self._run_nonlogger_callbacks(event)
ic('after _run_nonlogger_callbacks')
self._run_loggers(event)

if event.is_before_event and duration_marker is not None:
@@ -479,6 +481,7 @@ def _run_callbacks(

for cb in callbacks:
marker = None
ic(event, cb)
if self.state.profiler is not None:
marker = self.state.profiler.marker(
f'callback/{cb.__class__.__name__}/event/{event.value}',
@@ -490,7 +493,9 @@
ctx = cast(ContextManager, contextlib.nullcontext()) if marker is None else marker
with ctx:
self._debug_log(event, f'Running callback {type(cb).__name__}')
ic('before', event, cb)
cb.run_event(event, self.state, self.logger)
ic('after', event, cb)

def _run_loggers(self, event: Union[Event, str]):
loggers = [callback for callback in self.state.callbacks if isinstance(callback, LoggerDestination)]
7 changes: 6 additions & 1 deletion composer/core/state.py
@@ -976,6 +976,7 @@ def get_model_state_dict(self) -> dict[str, Any]:
),
)

ic('before get_model_state_dict')
model_state_dict = get_model_state_dict(
model=self.model,
submodules=None,
@@ -984,6 +985,7 @@ def get_model_state_dict(self) -> dict[str, Any]:
cpu_offload=self.fsdp_enabled,
),
)
ic('after get_model_state_dict')
else:
if self.fsdp_enabled and self.fsdp_state_dict_type is not None:
with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type):
@@ -1006,6 +1008,7 @@ def get_optim_state_dict(self) -> dict[str, Any]:
if version.parse(torch.__version__) >= version.parse('2.4.0') or (
version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized()
):

from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -1015,8 +1018,8 @@
'fsdp_state_dict_type to None, "full", or "sharded".',
),
)

optimizer = ensure_tuple(self.optimizers)[0]
ic('before get_optimizer_state_dict')
optim_state_dict = get_optimizer_state_dict(
model=self.model,
optimizers=optimizer,
@@ -1026,6 +1029,7 @@
cpu_offload=self.fsdp_enabled,
),
)
ic('after get_optimizer_state_dict')
return {type(optimizer).__qualname__: optim_state_dict}
else:
optimizer = ensure_tuple(self.optimizers)[0]
@@ -1046,6 +1050,7 @@ def state_dict(self) -> dict[str, Any]:
"""
state_dict = {}
for attribute_name in self.serialized_attributes:
ic(attribute_name)
attribute_value = getattr(self, attribute_name)
if attribute_name == 'dataset_state':
serialized_value = self._dataset_state_dict()
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -2667,7 +2667,6 @@ def _train_loop(self) -> None:
self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)

self.engine.run_event(Event.AFTER_DATALOADER)

self.engine.run_event(Event.BATCH_START)

# Log time values
@@ -2736,8 +2735,9 @@ def _train_loop(self) -> None:
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.BATCH_END)
last_wct = datetime.datetime.now() - duration

ic('before self.engine.run_event(Event.BATCH_CHECKPOINT)')
self.engine.run_event(Event.BATCH_CHECKPOINT)
ic('after self.engine.run_event(Event.BATCH_CHECKPOINT)')

if (
self.state.timestamp >= self.state.max_duration or (
7 changes: 4 additions & 3 deletions composer/utils/checkpoint.py
@@ -732,7 +732,7 @@ def load_sharded_checkpoint(
num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
state_dict: dict[str, Any] = {
'state': cur_state_dict,
'rng': reproducibility.get_rng_state()[:num_rng_ranks],
'rng': 42 # reproducibility.get_rng_state()[:num_rng_ranks],
}

if ignore_keys:
@@ -1132,7 +1132,6 @@ def _save_checkpoint(
) -> Union[str, None]: # noqa: D103

is_deepspeed = is_model_deepspeed(state.model)

if weights_only and not is_deepspeed:
state_dict = {
'state': {
Expand All @@ -1142,10 +1141,12 @@ def _save_checkpoint(
},
}
else:
ic('before reproducibility.get_rng_state()')
state_dict = {
'state': state.state_dict(),
'rng': reproducibility.get_rng_state(),
'rng': reproducibility.get_rng_state(), # 42
}
ic('after reproducibility.get_rng_state()')

if ignore_keys:
# Filter provided list of key paths
3 changes: 3 additions & 0 deletions composer/utils/dist.py
@@ -460,7 +460,10 @@ def all_gather_object(obj: TObj, group=None) -> list[TObj]:
if is_hpu_installed():
all_gather_object_list_hpu(obj_gather_list, obj, group=group)
else:
ic('before all_gather_object')
ic(obj_gather_list, obj.keys(), group)
dist.all_gather_object(obj_gather_list, obj, group=group)
ic('after all_gather_object')
# torch.distributed will replace the None's in obj_gather_list with the gathered objects on rank 0
# or will just be None on non-rank-0
return cast(list[TObj], obj_gather_list)
4 changes: 2 additions & 2 deletions composer/utils/reproducibility.py
@@ -190,8 +190,8 @@ def get_rng_state() -> list[dict[str, Any]]:
if torch.cuda.is_available() and torch.cuda.is_initialized():
# This will not be compatible with model parallelism
rng_state['cuda'] = torch.cuda.get_rng_state()

return dist.all_gather_object(rng_state)
ic('before dist.all_gather_object(rng_state)')
return ic(dist.all_gather_object(rng_state))


def load_rng_state(rng_state_dicts: list[dict[str, Any]]):