
Commit

better dataset
eitanturok committed Sep 23, 2024
1 parent 895f08e commit 07196dd
Showing 3 changed files with 57 additions and 57 deletions.
2 changes: 2 additions & 0 deletions tests/common/__init__.py
@@ -8,6 +8,7 @@
InfiniteClassificationDataset,
ParityDataset,
RandomClassificationDataset,
RandomClassificationDatasetReplicated,
RandomImageDataset,
RandomSegmentationDataset,
RandomTextClassificationDataset,
@@ -43,6 +44,7 @@ def get_module_subclasses(module: types.ModuleType, cls: type) -> list[type]:
__all__ = [
'assert_state_equivalent',
'RandomClassificationDataset',
'RandomClassificationDatasetReplicated',
'RandomTextClassificationDataset',
'RandomTextLMDataset',
'RandomImageDataset',
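For context, a minimal usage sketch of the new export (not part of the diff; the keyword values are illustrative, and it assumes composer's dist helpers fall back to a single-process view when torch.distributed is not initialized):

from tests.common import RandomClassificationDatasetReplicated

# Illustrative values only; rank/world size come from composer's dist helpers.
dataset = RandomClassificationDatasetReplicated(size=8, num_classes=2, replication=2)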
41 changes: 40 additions & 1 deletion tests/common/datasets.py
@@ -8,7 +8,7 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchvision.datasets import VisionDataset

from composer.utils import dist
from composer.utils import dist, reproducibility
from tests.common.models import configure_tiny_bert_tokenizer, configure_tiny_gpt2_tokenizer


@@ -86,6 +86,45 @@ def __getitem__(self, index: int):
return self.x[index], self.y[index]


class RandomClassificationDatasetReplicated(RandomClassificationDataset):
"""Like RandomClassificationDataset but samples are replicated across tensor parallelism groups."""

def __init__(
self,
shape: Sequence[int] = (1, 1, 1),
size: int = 100,
num_classes: int = 2,
device: Optional[torch.device] = None,
seed: int = 44,
replication: Optional[int] = 2,
):
super().__init__(shape, size, num_classes, device)
self.rank = dist.get_local_rank()
self.world_size = dist.get_world_size()
self.n_tp_groups = replication # the number of tp groups that we are replicating across
self.seed = seed

def _generate_data(self):
tp_group_id = self.rank // self.n_tp_groups
seed = self.seed + tp_group_id # all ranks in the same TP group have the same seed
reproducibility.seed_all(seed)
self.x = torch.randn(self.size, *self.shape, device=self.device)
self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)

def __len__(self):
return self.size

def __getitem__(self, idx):
if self.x is None and self.y is None:
self._generate_data()

assert self.x is not None
assert self.y is not None

rank_idx = idx // self.world_size
return self.x[rank_idx], self.y[rank_idx]


class RandomImageDataset(VisionDataset):
""" Image Classification dataset with values drawn from a normal distribution
Args:
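To make the replication scheme added above concrete, here is a minimal standalone sketch (not part of the commit) of the seed and index mapping that _generate_data and __getitem__ implement, assuming world_size=4 and replication=2 as in the TP tests; tp_seed and local_index are hypothetical helper names introduced only for this illustration:

def tp_seed(rank: int, replication: int, base_seed: int = 44) -> int:
    # Mirrors _generate_data: every rank in the same TP group derives the
    # same seed, so replicated ranks draw identical samples.
    tp_group_id = rank // replication
    return base_seed + tp_group_id

def local_index(idx: int, world_size: int) -> int:
    # Mirrors __getitem__: the global dataloader index is folded by world size.
    return idx // world_size

# Ranks 0 and 1 form one TP group, ranks 2 and 3 another.
assert tp_seed(0, 2) == tp_seed(1, 2) == 44
assert tp_seed(2, 2) == tp_seed(3, 2) == 45
# Global index 5 maps to local sample 1 on every rank.
assert local_index(5, world_size=4) == 1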
71 changes: 15 additions & 56 deletions tests/trainer/test_tp.py
@@ -1,7 +1,7 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional, Sequence, TypeVar, Union
from typing import Any, Optional, TypeVar, Union

E = TypeVar('E', bound=BaseException)

@@ -19,62 +19,13 @@
from composer.utils import FSDPConfig, ParallelismConfig, TPConfig, dist, reproducibility
from tests.common import (
RandomClassificationDataset,
RandomClassificationDatasetReplicated,
SimpleComposerMLP,
SimpleModel,
world_size,
)


class RandomClassificationDatasetReplicated(Dataset):
"""Like RandomClassificationDataset but samples are replicated across TP groups.
Args:
shape (Sequence[int]): shape of features (default: (1, 1, 1))
size (int): number of samples (default: 100)
num_classes (int): number of classes (default: 2)
"""

def __init__(
self,
shape: Sequence[int] = (1, 1, 1),
size: int = 100,
num_classes: int = 2,
device: Optional[torch.device] = None,
seed: int = 44,
replication: int = 2,
):
self.size = size
self.shape = shape
self.num_classes = num_classes
self.device = device
self.rank = dist.get_local_rank()
self.world_size = dist.get_world_size()
self.n_tp_groups = replication # the number of tp groups that we are replicating across
self.seed = seed
self.x: Optional[torch.Tensor] = None
self.y: Optional[torch.Tensor] = None

def _generate_data(self):
tp_group_id = self.rank // self.n_tp_groups
seed = self.seed + tp_group_id # all ranks in the same TP group have the same seed
reproducibility.seed_all(seed)
self.x = torch.randn(self.size, *self.shape, device=self.device)
self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)

def __len__(self):
return self.size

def __getitem__(self, idx):
if self.x is None and self.y is None:
self._generate_data()

assert self.x is not None
assert self.y is not None

rank_idx = idx // self.world_size
return self.x[rank_idx], self.y[rank_idx]


def get_trainer(
parallelism_config: Optional[ParallelismConfig] = None,
size: int = 4,
@@ -83,7 +34,7 @@ def get_trainer(
num_features: int = 2,
seed: int = 44,
device: Union[torch.device, str] = 'cuda',
replication: int = 0,
replication: Optional[int] = None,
):
"""Trainer for a simple model with any parallelism_config."""

@@ -131,7 +82,7 @@ def get_ddp_trainer(
num_features: int = 2,
seed: int = 44,
device: Union[torch.device, str] = 'cuda',
replication: int = 0,
replication: Optional[int] = None,
):
ddp_trainer = get_trainer(
size=size,
@@ -152,7 +103,7 @@ def get_fsdp_trainer(
num_features: int = 2,
seed: int = 44,
device: Union[torch.device, str] = 'cuda',
replication: int = 0,
replication: Optional[int] = None,
):
fsdp_config = FSDPConfig(
state_dict_type='full',
@@ -182,7 +133,7 @@ def get_tp_fsdp_trainer(
num_features: int = 2,
seed: int = 44,
device: Union[torch.device, str] = 'cuda',
replication: int = 0,
replication: Optional[int] = None,
):
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

@@ -200,7 +151,7 @@

tp_config = TPConfig(
layer_plan=layer_plan,
tensor_parallel_degree=replication,
tensor_parallel_degree=1 if replication is None else replication,
)
parallelism_config = ParallelismConfig(fsdp=fsdp_config, tp=tp_config)

@@ -559,3 +510,11 @@ def test_tp_fsdp_state_dict(world_size: int):
tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict() # fails always

compare_modules(tp_fsdp_state_dict1['model'], tp_fsdp_state_dict2['model'])


if __name__ == '__main__':
world_size = 4
replication = 2
test_tp_forwards_backwards_correctness(world_size, replication)

test_tp_fit_correctness(world_size, 4, replication)
