From 07196dd5b9f960f6ef1cf9ceea924e5dd6785c23 Mon Sep 17 00:00:00 2001
From: Eitan Turok
Date: Mon, 23 Sep 2024 21:50:31 +0000
Subject: [PATCH] Move RandomClassificationDatasetReplicated into tests/common

Move RandomClassificationDatasetReplicated out of tests/trainer/test_tp.py
and into tests/common/datasets.py, deriving it from
RandomClassificationDataset so it can be reused by other tests. Also change
the `replication` arguments from `int = 0` to `Optional[int] = None`, and
map `replication=None` to a tensor_parallel_degree of 1.
---
 tests/common/__init__.py |  2 ++
 tests/common/datasets.py | 41 ++++++++++++++++++++++-
 tests/trainer/test_tp.py | 71 +++++++++------------------------------
 3 files changed, 57 insertions(+), 57 deletions(-)

diff --git a/tests/common/__init__.py b/tests/common/__init__.py
index 89715ef34c..e843fe5d47 100644
--- a/tests/common/__init__.py
+++ b/tests/common/__init__.py
@@ -8,6 +8,7 @@
     InfiniteClassificationDataset,
     ParityDataset,
     RandomClassificationDataset,
+    RandomClassificationDatasetReplicated,
     RandomImageDataset,
     RandomSegmentationDataset,
     RandomTextClassificationDataset,
@@ -43,6 +44,7 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
 __all__ = [
     'assert_state_equivalent',
     'RandomClassificationDataset',
+    'RandomClassificationDatasetReplicated',
     'RandomTextClassificationDataset',
     'RandomTextLMDataset',
     'RandomImageDataset',

diff --git a/tests/common/datasets.py b/tests/common/datasets.py
index 37a35af6b8..e91ce125e5 100644
--- a/tests/common/datasets.py
+++ b/tests/common/datasets.py
@@ -8,7 +8,7 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from torchvision.datasets import VisionDataset
 
-from composer.utils import dist
+from composer.utils import dist, reproducibility
 from tests.common.models import configure_tiny_bert_tokenizer, configure_tiny_gpt2_tokenizer
 
 
@@ -86,6 +86,45 @@ def __getitem__(self, index: int):
         return self.x[index], self.y[index]
 
 
+class RandomClassificationDatasetReplicated(RandomClassificationDataset):
+    """Like RandomClassificationDataset, but samples are replicated across tensor parallelism groups."""
+
+    def __init__(
+        self,
+        shape: Sequence[int] = (1, 1, 1),
+        size: int = 100,
+        num_classes: int = 2,
+        device: Optional[torch.device] = None,
+        seed: int = 44,
+        replication: Optional[int] = 2,
+    ):
+        super().__init__(shape, size, num_classes, device)
+        self.rank = dist.get_local_rank()
+        self.world_size = dist.get_world_size()
+        self.n_tp_groups = replication  # the number of TP groups that we are replicating across
+        self.seed = seed
+
+    def _generate_data(self):
+        tp_group_id = self.rank // self.n_tp_groups
+        seed = self.seed + tp_group_id  # all ranks in the same TP group share the same seed
+        reproducibility.seed_all(seed)
+        self.x = torch.randn(self.size, *self.shape, device=self.device)
+        self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
+
+    def __len__(self):
+        return self.size
+
+    def __getitem__(self, idx):
+        if self.x is None and self.y is None:
+            self._generate_data()
+
+        assert self.x is not None
+        assert self.y is not None
+
+        rank_idx = idx // self.world_size
+        return self.x[rank_idx], self.y[rank_idx]
+
+
 class RandomImageDataset(VisionDataset):
     """ Image Classification dataset with values drawn from a normal distribution
     Args:
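Usage sketch (illustrative, not part of the applied patch): ranks that share a tensor-parallel group derive the same seed, so they generate identical samples. The snippet assumes a 4-rank distributed run with the defaults above (seed=44, replication=2); the direct indexing only demonstrates the replication behavior, not how the Trainer's sampler consumes the dataset.

    from tests.common import RandomClassificationDatasetReplicated

    # Assumes dist.get_world_size() == 4 and a per-rank local rank is set.
    dataset = RandomClassificationDatasetReplicated(size=8, replication=2)

    # First access lazily seeds torch with seed + tp_group_id, then samples:
    # ranks 0 and 1 (TP group 0) use seed 44, ranks 2 and 3 (TP group 1) use
    # seed 45, so data is identical within a TP group and differs across groups.
    x, y = dataset[0]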
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index c63f518c65..374f111cd6 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -1,7 +1,7 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Optional, Sequence, TypeVar, Union
+from typing import Any, Optional, TypeVar, Union
 
 E = TypeVar('E', bound=BaseException)
 
@@ -19,62 +19,13 @@
 from composer.utils import FSDPConfig, ParallelismConfig, TPConfig, dist, reproducibility
 from tests.common import (
     RandomClassificationDataset,
+    RandomClassificationDatasetReplicated,
     SimpleComposerMLP,
     SimpleModel,
     world_size,
 )
 
 
-class RandomClassificationDatasetReplicated(Dataset):
-    """Like RandomClassificationDataset but samples are replicated across TP groups.
-
-    Args:
-        shape (Sequence[int]): shape of features (default: (1, 1, 1))
-        size (int): number of samples (default: 100)
-        num_classes (int): number of classes (default: 2)
-    """
-
-    def __init__(
-        self,
-        shape: Sequence[int] = (1, 1, 1),
-        size: int = 100,
-        num_classes: int = 2,
-        device: Optional[torch.device] = None,
-        seed: int = 44,
-        replication: int = 2,
-    ):
-        self.size = size
-        self.shape = shape
-        self.num_classes = num_classes
-        self.device = device
-        self.rank = dist.get_local_rank()
-        self.world_size = dist.get_world_size()
-        self.n_tp_groups = replication  # the number of tp groups that we are replicating across
-        self.seed = seed
-        self.x: Optional[torch.Tensor] = None
-        self.y: Optional[torch.Tensor] = None
-
-    def _generate_data(self):
-        tp_group_id = self.rank // self.n_tp_groups
-        seed = self.seed + tp_group_id  # all ranks in the same TP group have the same seed
-        reproducibility.seed_all(seed)
-        self.x = torch.randn(self.size, *self.shape, device=self.device)
-        self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
-
-    def __len__(self):
-        return self.size
-
-    def __getitem__(self, idx):
-        if self.x is None and self.y is None:
-            self._generate_data()
-
-        assert self.x is not None
-        assert self.y is not None
-
-        rank_idx = idx // self.world_size
-        return self.x[rank_idx], self.y[rank_idx]
-
-
 def get_trainer(
     parallelism_config: Optional[ParallelismConfig] = None,
     size: int = 4,
@@ -83,7 +34,7 @@ def get_trainer(
     num_features: int = 2,
     seed: int = 44,
     device: Union[torch.device, str] = 'cuda',
-    replication: int = 0,
+    replication: Optional[int] = None,
 ):
     """Trainer for a simple model with any parallelism_config."""
 
@@ -131,7 +82,7 @@ def get_ddp_trainer(
     num_features: int = 2,
     seed: int = 44,
     device: Union[torch.device, str] = 'cuda',
-    replication: int = 0,
+    replication: Optional[int] = None,
 ):
     ddp_trainer = get_trainer(
         size=size,
@@ -152,7 +103,7 @@ def get_fsdp_trainer(
     num_features: int = 2,
     seed: int = 44,
     device: Union[torch.device, str] = 'cuda',
-    replication: int = 0,
+    replication: Optional[int] = None,
 ):
     fsdp_config = FSDPConfig(
         state_dict_type='full',
@@ -182,7 +133,7 @@ def get_tp_fsdp_trainer(
     num_features: int = 2,
     seed: int = 44,
     device: Union[torch.device, str] = 'cuda',
-    replication: int = 0,
+    replication: Optional[int] = None,
 ):
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
 
@@ -200,7 +151,7 @@ def get_tp_fsdp_trainer(
 
     tp_config = TPConfig(
         layer_plan=layer_plan,
-        tensor_parallel_degree=replication,
+        tensor_parallel_degree=1 if replication is None else replication,
    )
     parallelism_config = ParallelismConfig(fsdp=fsdp_config, tp=tp_config)
 
@@ -559,3 +510,11 @@ def test_tp_fsdp_state_dict(world_size: int):
     tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict()
     # fails always
     compare_modules(tp_fsdp_state_dict1['model'], tp_fsdp_state_dict2['model'])
+
+
+if __name__ == '__main__':
+    world_size = 4
+    replication = 2
+    test_tp_forwards_backwards_correctness(world_size, replication)
+
+    test_tp_fit_correctness(world_size, 4, replication)
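A note on the changed `replication` default (illustrative, not code from this commit): previously `replication` defaulted to 0 and flowed straight into `TPConfig(tensor_parallel_degree=replication)`, i.e. a zero-degree config. With the new `Optional[int]` signature, `None` is mapped to a degree of 1, meaning tensor parallelism is effectively disabled rather than misconfigured. A minimal sketch of that guard, using a hypothetical helper name:

    from typing import Optional

    def resolve_tp_degree(replication: Optional[int]) -> int:
        # Mirrors the inline expression added to get_tp_fsdp_trainer:
        # replication=None means "no replication", which maps to TP degree 1.
        return 1 if replication is None else replication

    assert resolve_tp_degree(None) == 1  # no replication -> single-rank TP group
    assert resolve_tp_degree(2) == 2     # two ranks per TP group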