Revert TP integration (mosaicml#3328)
* Revert "Bugfixes to FSDP + TP (mosaicml#3323)"

This reverts commit 79e79eb.

* Revert "Tensor Parallelism Integration (mosaicml#3269)"

This reverts commit 09f14f9.
dakinggg committed May 25, 2024
1 parent 0be5d30 commit f154ea6
Showing 32 changed files with 829 additions and 1,147 deletions.
5 changes: 2 additions & 3 deletions composer/callbacks/checkpoint_saver.py
@@ -468,9 +468,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
assert state.fsdp_config is not None
remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None
assert state.sharded_ckpt_prefix_dir is not None
remote_prefix = state.sharded_ckpt_prefix_dir
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
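# Illustrative sketch, not part of the commit diff: how the restored lines above assemble
# the remote name for a sharded checkpoint. The concrete values are hypothetical, and the
# filename string stands in for checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME.
import os
import pathlib

remote_file_name = 'my-run/checkpoints/ep{epoch}-ba{batch}-rank{rank}.pt'
remote_prefix = 'ep{epoch}-ba{batch}'  # stands in for state.sharded_ckpt_prefix_dir
ckpt_filename = '__{rank}_0.distcp'    # placeholder for the torch distributed checkpoint filename

remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
# -> 'my-run/checkpoints/ep{epoch}-ba{batch}/__{rank}_0.distcp', which is then expanded by
#    format_name_with_dist_and_time with the run name and timestamp, as above.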
215 changes: 57 additions & 158 deletions composer/core/state.py
@@ -17,7 +17,6 @@
import torch
import torch.nn.modules.utils
from packaging import version
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
@@ -31,6 +30,8 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

from composer.utils.warnings import VersionedDeprecationWarning

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
@@ -43,8 +44,6 @@
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.devices import Device
from composer.utils import (
ParallelismType,
VersionedDeprecationWarning,
batch_get,
batch_set,
dist,
@@ -195,79 +194,6 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
return state


def _create_device_mesh(
device: Device,
fsdp_config: Optional[Dict[str, Any]],
tp_config: Optional[Dict[str, Any]],
) -> Optional[DeviceMesh]:
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
# Device mesh has correctness issues before torch 2.3.0
return None

if fsdp_config is None:
return None

# Gather dimensions and names for the device mesh
dims: List[int] = []
names: List[str] = []
if fsdp_config['data_parallel_replicate_degree'] != 1:
dims.append(fsdp_config['data_parallel_replicate_degree'])
names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
dims.append(fsdp_config['data_parallel_shard_degree'])
names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
if tp_config is not None:
dims.append(tp_config['tensor_parallel_degree'])
names.append(ParallelismType.TENSOR_PARALLEL.value)

# Fill in the unspecified dimensions
product_of_dims = 1
unspecified_dim_names = []
for dim, name in zip(dims, names):
if dim != -1:
product_of_dims *= dim
else:
unspecified_dim_names.append(name)
if len(unspecified_dim_names) > 1:
raise ValueError(
f'Found multiple parallelism dimensions with -1: {unspecified_dim_names}. '
'Only one is allowed, which is set to fill the remaining dimensions.',
)
elif len(unspecified_dim_names) == 1:
if product_of_dims > dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} is greater than the product of the specified parallelism degrees '
f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ',
f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which '
'will automatically be specified to ensure the product matches the world size.',
)
remaining_dimension = dist.get_world_size() // product_of_dims
if remaining_dimension * product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} is not divisible by the product of the specified '
'parallelism degrees. Please ensure the product of the specified parallelism degrees '
'matches the world size.',
)
for i, dim in enumerate(dims):
if dim == -1:
dims[i] = remaining_dimension
log.info(f'Automatically setting {names[i]} to have parallelization degree {remaining_dimension}.')
break
else:
if product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} does not equal the product of the specified parallelism degrees '
f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world ',
f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which '
'will automatically be specified to ensure the product matches the world size.',
)

device_type = device.name
if device_type == 'gpu':
device_type = 'cuda'

return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names))
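# Illustrative sketch, not part of the commit diff: the -1 fill-in behaviour of the removed
# helper above, under hypothetical settings. With a world size of 8, data_parallel_shard_degree=-1
# and tensor_parallel_degree=2, the -1 is replaced by 8 // 2 = 4, so init_device_mesh would
# receive mesh_shape=(4, 2) with the corresponding ParallelismType mesh_dim_names.
world_size = 8
dims = [-1, 2]  # [data_parallel_shard_degree, tensor_parallel_degree]
product_of_dims = 1
for dim in dims:
    if dim != -1:
        product_of_dims *= dim
dims = [world_size // product_of_dims if dim == -1 else dim for dim in dims]
assert dims == [4, 2]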


_STATE_DICT_SERIALIZED_ATTRIBUTES = [
# List of attributes that are serialized with state_dict
# Only the attributes listed in state.serialized_attributes will actually be saved.
@@ -329,7 +255,8 @@ class State(Serializable):
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
deepspeed_config (Dict[str, Any], optional): The configuration dictionary for deepspeed.
parallelism_config (Dict[str, Any], optional): The configuration dictionary for parallelism.
fsdp_config (Dict[str, Any], optional): The configuration dictionary for FSDP.
fsdp_auto_wrap (bool, optional): Whether to automatically wrap the model with FSDP.
Attributes:
batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a
@@ -496,7 +423,8 @@ def __init__(

# Distributed training configs
deepspeed_config: Optional[Dict[str, Any]] = None,
parallelism_config: Optional[Dict[str, Any]] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
fsdp_auto_wrap: bool = True,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
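# Illustrative sketch, not part of the commit diff: the two constructor surfaces seen in the
# signature above. Keyword names are taken from the diff; the nested keys and degree values
# are hypothetical examples.
#
# Pre-revert (TP integration): a single parallelism_config with 'fsdp' and 'tp' sub-configs.
parallelism_config = {
    'fsdp': {'use_orig_params': True},
    'tp': {'tensor_parallel_degree': 2},
}
# Post-revert (this commit): separate fsdp_config and fsdp_auto_wrap arguments, no TP.
fsdp_config = {'use_orig_params': True}
fsdp_auto_wrap = True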
@@ -540,88 +468,20 @@ def __init__(
self.profiler: Optional[Profiler] = None

self.deepspeed_config = deepspeed_config
parallelism_config = parallelism_config or {}
self.fsdp_config = parallelism_config.get('fsdp', None)
self.tp_config = parallelism_config.get('tp', None)

self._validate_parallelism_configs()

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
if self.fsdp_config is not None and self.device_mesh is not None:
fsdp_mesh_dim_names = []
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]
self.fsdp_config = fsdp_config
self.fsdp_auto_wrap = fsdp_auto_wrap

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _validate_parallelism_configs(self):
# Validate TP config
if self.tp_config is not None:
warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
if self.fsdp_config is None:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP to be enabled. '
'An empty `fsdp_config` can be specified to enable FSDP with '
'default settings. Additionally, PyTorch currently errors if FSDP '
'data_parallel_shard_degree is not at least 2.',
)
if not self.fsdp_config['use_orig_params']:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
)

# Load monolith rank0 only
if self.load_monolith_rank0_only:
if self.tp_config is not None:
raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).')
assert self.fsdp_config is not None
assert fsdp_config is not None
error_message = ''
if self.fsdp_config['sync_module_states'] == False:
if fsdp_config['sync_module_states'] == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
if dist.get_global_rank() == 0 and next(self.model.parameters()).device.type == 'meta':
if dist.get_global_rank() == 0 and next(model.parameters()).device.type == 'meta':
rank0_on_meta = 1
rank0_on_meta_tensor = self.device.tensor_to_device(torch.tensor([rank0_on_meta], dtype=torch.uint8))
dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
@@ -634,7 +494,10 @@ def _validate_parallelism_configs(self):
if error_message != '':
raise ValueError(error_message)

# Validate FSDP state dict type
self.sharded_ckpt_prefix_dir: Optional[str] = None
if self.fsdp_config is not None:
self.sharded_ckpt_prefix_dir = self.fsdp_config['sharded_ckpt_prefix_dir']

if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
if self.fsdp_state_dict_type == 'local':
raise ValueError(
@@ -658,6 +521,39 @@
),
)

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.
@@ -898,8 +794,12 @@ def fsdp_sharded_state_dict_enabled(self):

@property
def fsdp_device_mesh(self):
warnings.warn(VersionedDeprecationWarning('fsdp_device_mesh is deprecated. Use device_mesh instead.', '0.24'))
return self.device_mesh
if self.fsdp_enabled:
if not hasattr(self.model, 'model') or not hasattr(self.model.model, '_device_mesh'):
return None
return self.model.model._device_mesh
else:
return None
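# Illustrative sketch, not part of the commit diff: after the revert, the device mesh is read
# off the FSDP-wrapped module instead of being stored on State. The attribute chain mirrors
# the restored property above; `composer_model` is a hypothetical wrapped model.
def get_fsdp_device_mesh(composer_model):
    inner = getattr(composer_model, 'model', None)
    return getattr(inner, '_device_mesh', None) if inner is not None else None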

@property
def load_fsdp_monolith_rank0_only(self):
@@ -914,8 +814,8 @@
@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and
self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
self.fsdp_config['load_monolith_rank0_only'] == True
)
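# Illustrative sketch, not part of the commit diff: a hypothetical configuration under which
# the restored property above evaluates to True (keys taken from the expression).
fsdp_config = {'state_dict_type': 'full', 'load_monolith_rank0_only': True}
fsdp_auto_wrap = True
load_monolith_rank0_only = (
    fsdp_config is not None and fsdp_auto_wrap and
    fsdp_config['state_dict_type'] == 'full' and fsdp_config['load_monolith_rank0_only'] == True
)
assert load_monolith_rank0_only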

def _get_integrations_state_dict(self) -> Dict[str, Any]:
@@ -1389,9 +1289,8 @@ def load_model_state(
if self.load_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
with reproducibility.seed_context(self.rank_zero_seed):
from composer.distributed import prepare_fsdp_module

prepare_fsdp_module(
self.model,
self.optimizers,
25 changes: 0 additions & 25 deletions composer/distributed/__init__.py

This file was deleted.

Diffs for the remaining changed files were not loaded.
