Skip to content

Commit

Permalink
Merge pull request #208 from BerndDoser/mnist
Browse files Browse the repository at this point in the history
MNIST DataModule with random rotations
  • Loading branch information
BerndDoser committed Sep 5, 2024
2 parents fb78432 + 17ebeda commit 1b478ab
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/spherinator/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .illustris_sdss_dataset_with_metadata import IllustrisSdssDatasetWithMetadata
from .images_data_module import ImagesDataModule
from .images_dataset import ImagesDataset
from .mnist_data_module import MNISTDataModule
from .shapes_data_module import ShapesDataModule
from .shapes_dataset import ShapesDataset
from .spherinator_data_module import SpherinatorDataModule
Expand All @@ -22,6 +23,7 @@
"IllustrisSdssDatasetWithMetadata",
"ImagesDataModule",
"ImagesDataset",
"MNISTDataModule",
"ShapesDataModule",
"ShapesDataset",
"SpherinatorDataModule",
Expand Down
81 changes: 81 additions & 0 deletions src/spherinator/data/mnist_data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
def __init__(
self,
data_dir: str = "./data/",
random_rotation: bool = True,
batch_size: int = 32,
num_workers: int = 4,
):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers

transformations = [
transforms.ToTensor(),
transforms.Pad((0, 0, 1, 1), fill=0),
transforms.Resize((87, 87)),
]
if random_rotation:
transformations += [transforms.RandomAffine(degrees=[0, 360])]
transformations += [
transforms.Resize((29, 29)),
transforms.Lambda(
lambda x: (x - x.min()) / (x.max() - x.min())
), # Normalize to [0, 1]
]
self.transform = transforms.Compose(transformations)

def prepare_data(self):
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)

def setup(self, stage: str):
if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

if stage == "test":
self.mnist_test = MNIST(
self.data_dir, train=False, transform=self.transform
)

if stage == "predict":
self.mnist_predict = MNIST(
self.data_dir, train=False, transform=self.transform
)

def train_dataloader(self):
return DataLoader(
self.mnist_train,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
)

def val_dataloader(self):
return DataLoader(
self.mnist_val,
batch_size=self.batch_size,
num_workers=self.num_workers,
)

def test_dataloader(self):
return DataLoader(
self.mnist_test,
batch_size=self.batch_size,
num_workers=self.num_workers,
)

def predict_dataloader(self):
return DataLoader(
self.mnist_predict,
batch_size=self.batch_size,
num_workers=self.num_workers,
)
29 changes: 29 additions & 0 deletions tests/test_mnist_data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np
import torch

from spherinator.data import MNISTDataModule


def test_fit(tmp_path):
data = MNISTDataModule(tmp_path, random_rotation=True, num_workers=1, batch_size=2)
data.prepare_data()
data.setup("fit")

assert len(data.mnist_train) == 55000

dataloader = data.train_dataloader()

assert dataloader.batch_size == 2
assert len(dataloader) == 27500
assert dataloader.num_workers == 1

batch = next(iter(dataloader))
images, labels = batch

assert images.shape == (2, 1, 29, 29)
assert images.dtype == torch.float32

assert np.isclose(images.min(), 0.0)
assert np.isclose(images.max(), 1.0)

assert (labels == torch.Tensor([1, 3])).all()

0 comments on commit 1b478ab

Please sign in to comment.