
Commit

use ReplicatedDataset
ez2rok committed Sep 20, 2024
1 parent f022b21 commit 3257430
Showing 2 changed files with 10 additions and 142 deletions.
composer/models/tasks/classification.py (2 changes: 0 additions & 2 deletions)
@@ -133,6 +133,4 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
     def forward(self, batch: tuple[Tensor, Any]) -> Tensor:
         inputs, _ = batch
         outputs = self.module(inputs)
-        from icecream import ic
-        ic(inputs, outputs)
         return outputs
tests/trainer/test_tp.py (150 changes: 10 additions & 140 deletions)
@@ -1,7 +1,6 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-import math
 from typing import Any, Optional, Sequence, TypeVar
 
 E = TypeVar('E', bound=BaseException)
@@ -10,11 +9,10 @@
 import pytest
 import torch
 from packaging import version
-import torch.distributed as tdist
-from torch.distributed._tensor import DTensor, Replicate, Shard
+from torch.distributed._tensor import DTensor, Replicate
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
-from torch.utils.data import DataLoader, Dataset, Sampler
+from torch.utils.data import DataLoader, Dataset
 
 from composer.callbacks import MemoryMonitor
 from composer.loggers import InMemoryLogger
@@ -27,8 +25,6 @@
     world_size,
 )
 
-from icecream import ic
-
 
 class RandomClassificationDatasetReplicated(Dataset):
     """Like RandomClassificationDataset but samples are replicated across TP groups.
@@ -72,7 +68,6 @@ def __len__(self):
     def __getitem__(self, idx):
         if self.x is None and self.y is None:
            self._generate_data()
-            ic(self.x, self.y)
 
         assert self.x is not None
         assert self.y is not None
@@ -81,72 +76,6 @@ def __getitem__(self, idx):
         return self.x[rank_idx], self.y[rank_idx]
 
 
-class CustomDistributedSampler(Sampler):
-    def __init__(
-        self,
-        dataset,
-        num_replicas=None,
-        rank=None,
-        shuffle=True,
-        seed=0,
-        drop_last=False,
-        replication=0,
-    ):
-        num_replicas = dist.get_world_size()
-        rank = dist.get_local_rank()
-
-        self.dataset = dataset
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.epoch = 0
-        self.drop_last = drop_last
-        self.shuffle = shuffle
-        self.seed = seed
-        self.replication = replication
-
-        self.tensor_parallelism_group = self.rank // self.replication
-        self.tensor_parallelism_id = self.rank % self.replication
-        ic(self.tensor_parallelism_group, self.tensor_parallelism_id)
-
-        # Calculate the number of samples per tensor parallelism group
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
-        self.total_size = self.num_samples * self.num_replicas
-
-        # Adjust num_samples and total_size to ensure consistency across tensor parallelism groups
-        self.tp_group_size = self.replication
-        self.num_samples = int(math.ceil(self.num_samples * 1.0 / self.tp_group_size)) * self.tp_group_size
-        self.total_size = self.num_samples * self.num_replicas // self.tp_group_size
-
-    def __iter__(self):
-        indices = list(range(len(self.dataset)))
-
-        if not self.drop_last:
-            # add extra samples to make it evenly divisible
-            padding_size = self.total_size - len(indices)
-            if padding_size <= len(indices):
-                indices += indices[:padding_size]
-            else:
-                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
-        else:
-            # remove tail of data to make it evenly divisible.
-            indices = indices[:self.total_size]
-
-        assert len(indices) == self.total_size
-
-        # Subsample based on rank, but ensure consistency across tensor parallelism groups
-        tp_group_rank = dist.get_rank(self.tensor_parallelism_group)
-        indices = indices[tp_group_rank:self.total_size:self.tp_group_size]
-        assert len(indices) == self.num_samples // self.tp_group_size
-
-        return iter(indices)
-
-    def __len__(self):
-        return self.num_samples // self.tp_group_size
-
-    def set_epoch(self, epoch):
-        self.epoch = epoch
-
-
 def get_trainer(
     parallelism_config: Optional[ParallelismConfig] = None,
     size: int = 4,
@@ -161,16 +90,17 @@ def get_trainer(
 
     reproducibility.seed_all(seed)
 
-    dataset: Dataset = RandomClassificationDataset(
+    dataset: Dataset = RandomClassificationDatasetReplicated(
         shape=(num_features,),
         num_classes=num_classes,
         size=size,
         device=device,
+        replication=replication,
     )  # X=(num_features,), y=(,), i.e. scalar
 
     dataloader = DataLoader(
         dataset,
-        sampler=CustomDistributedSampler(dataset, replication=replication),
+        sampler=dist.get_sampler(dataset),
         batch_size=batch_size,
     )  # X=(batch_size, num_features), y=(batch_size,)

@@ -425,58 +355,6 @@ def test_tp_forwards_backwards(world_size: int, replication: int):
     compare_models(ddp_trainer, fsdp_trainer, tp_fsdp_trainer)
 
 
-
-class RandomClassificationDatasetReplicated2(Dataset):
-    """Like RandomClassificationDataset but samples are replicated across TP groups.
-    Args:
-        shape (Sequence[int]): shape of features (default: (1, 1, 1))
-        size (int): number of samples (default: 100)
-        num_classes (int): number of classes (default: 2)
-    """
-
-    def __init__(
-        self,
-        shape: Sequence[int] = (1, 1, 1),
-        size: int = 100,
-        num_classes: int = 2,
-        device: Optional[torch.device] = None,
-        seed: int = 44,
-        replication: int = 2,
-    ):
-        self.size = size
-        self.shape = shape
-        self.num_classes = num_classes
-        self.device = device
-        self.replication = replication  # the number of tp groups that we are replicating across
-        self.seed = seed
-        self.x: Optional[torch.Tensor] = None
-        self.y: Optional[torch.Tensor] = None
-
-        self.rank = dist.get_local_rank()
-        self.tp_group = self.rank // self.replication + 1
-        self.tp_idx = self.rank % self.replication
-
-    def __len__(self):
-        return self.size
-
-    def __getitem__(self, idx):
-        if self.x is None and self.y is None:
-            reproducibility.seed_all(self.seed)
-            self.x = torch.randn(self.size, *self.shape, device=self.device)
-            self.y = torch.randint(0, self.num_classes, size=(self.size,), device=self.device)
-            ic(self.x, self.y)
-
-        assert self.x is not None
-        assert self.y is not None
-
-
-        offset = idx % (self.tp_group * self.replication)
-        rank_idx = idx // self.replication + offset
-        ic(idx, self.tp_group, self.tp_idx, offset, rank_idx)
-        return self.x[rank_idx], self.y[rank_idx]
-
-
 @pytest.mark.gpu
 @world_size(4)
 @pytest.mark.parametrize('replication', [2])
@@ -493,27 +371,24 @@ def test_tp_fit(world_size: int, batch_size: int, replication: int):
 
     # Initialize number of samples in the dataset
     # train_steps = 20 # number of steps to train for
-    train_steps = 2
+    train_steps = 20
     samples_per_batch = world_size * batch_size // replication
-    n_samples = samples_per_batch * train_steps
+    dataset_size = samples_per_batch * train_steps
 
     # DDP fit
-    ic('DDP')
-    ddp_trainer = get_ddp_trainer(size=n_samples, batch_size=batch_size, replication=replication)
+    ddp_trainer = get_ddp_trainer(size=dataset_size, batch_size=batch_size, replication=replication)
     ddp_trainer.fit()
     ddp_trainer.close()
     ddp_stats = get_stats(ddp_trainer)
 
     # FSDP fit
-    ic('FSDP')
-    fsdp_trainer = get_fsdp_trainer(size=n_samples, batch_size=batch_size, replication=replication)
+    fsdp_trainer = get_fsdp_trainer(size=dataset_size, batch_size=batch_size, replication=replication)
    fsdp_trainer.fit()
     fsdp_trainer.close()
     fsdp_stats = get_stats(fsdp_trainer)
 
     # TP-FSDP fit
-    ic('TP-FSDP')
-    tp_fsdp_trainer = get_tp_fsdp_trainer(size=n_samples, batch_size=batch_size, replication=replication)
+    tp_fsdp_trainer = get_tp_fsdp_trainer(size=dataset_size, batch_size=batch_size, replication=replication)
     tp_fsdp_trainer.fit()
     tp_fsdp_trainer.close()
     tp_fsdp_stats = get_stats(tp_fsdp_trainer)
@@ -674,8 +549,3 @@ def test_tp_fsdp_state_dict(world_size: int):
     tp_fsdp_state_dict1 = tp_fsdp_trainer.state.state_dict()  # work sometimes, fails sometimes
     with FSDP.summon_full_params(tp_fsdp_trainer.state.model, with_grads=True):
         tp_fsdp_state_dict2 = tp_fsdp_trainer.state.state_dict()  # fails always
-
-
-
-if __name__ == '__main__':
-    test_tp_fit(4, 4, 2)
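Note on the replication scheme (illustrative, not part of this commit): the diff replaces the custom distributed sampler with RandomClassificationDatasetReplicated plus the stock dist.get_sampler, so that ranks within the same tensor-parallel group receive identical samples while different groups receive disjoint ones. Below is a minimal sketch of that idea, assuming world_size is divisible by replication; the helper name replicated_indices and its arguments are assumptions for the example, not code from the commit (the actual index math lives inside RandomClassificationDatasetReplicated).

# Illustrative sketch only: ranks that share a tensor-parallel (TP) group map to
# the same shard of the dataset, so they see identical samples; different TP
# groups see disjoint shards. All names here are assumptions for the example.
def replicated_indices(dataset_size: int, world_size: int, rank: int, replication: int) -> list[int]:
    """Indices visible to `rank` when samples are replicated within its TP group."""
    num_tp_groups = world_size // replication  # number of distinct data shards
    tp_group = rank // replication             # ranks in the same TP group share a shard
    return list(range(tp_group, dataset_size, num_tp_groups))

# With world_size=4 and replication=2: ranks 0 and 1 (TP group 0) both see
# indices [0, 2, 4, ...], while ranks 2 and 3 (TP group 1) both see [1, 3, 5, ...].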
