
Commit

Merge pull request #171 from BerndDoser/rot-loss
Alternative rotational invariance and convolutional neural networks
BerndDoser committed Jul 24, 2024
2 parents e165248 + 574960f commit 416e2f6
Showing 33 changed files with 1,349 additions and 704 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -2,10 +2,11 @@ __pycache__
 .venv*/
 *.ckpt
 *.ncu-rep
+*.sarif
 config.yaml
 dist/
+docs/_build/
+docs/html/
 HiPSter/
 lightning_logs/
 wandb/
-docs/_build/
-docs/html/
66 changes: 66 additions & 0 deletions experiments/pokemon.yaml
@@ -0,0 +1,66 @@
seed_everything: 42

model:
  class_path: spherinator.models.RotationalVariationalAutoencoderPower
  init_args:
    encoder:
      class_path: spherinator.models.ConvolutionalEncoder
    decoder:
      class_path: spherinator.models.ConvolutionalDecoder
    h_dim: 256
    z_dim: 3
    image_size: 224
    rotations: 1
    beta: 1.0e-3

data:
  class_path: spherinator.data.ImagesDataModule
  init_args:
    data_directory: /local_data/doserbd/data/pokemon
    extensions: ['jpg']
    image_size: 224
    batch_size: 32
    shuffle: True
    num_workers: 16

optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1.e-3

lr_scheduler:
  class_path: lightning.pytorch.cli.ReduceLROnPlateau
  init_args:
    mode: min
    factor: 0.1
    patience: 5
    cooldown: 5
    min_lr: 1.e-6
    monitor: train_loss
    verbose: True

trainer:
  max_epochs: -1
  accelerator: gpu
  devices: [3]
  precision: 32
  callbacks:
    - class_path: spherinator.callbacks.LogReconstructionCallback
      init_args:
        num_samples: 6
    # - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    #   init_args:
    #     monitor: train_loss
    #     filename: "{epoch}-{train_loss:.2f}"
    #     save_top_k: 3
    #     mode: min
    #     every_n_epochs: 1
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: spherinator
      log_model: True
      entity: ain-space
      tags:
        - rot-loss
        - pokemon
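
This experiment file follows the Lightning CLI class_path/init_args convention, so the model, data module, optimizer, scheduler, and trainer are all instantiated from the config alone. A minimal sketch of an entry point that would consume it (not part of this commit; the file name main.py is an assumption):

if __name__ == "__main__":
    from lightning.pytorch.cli import LightningCLI

    # Builds model, data, optimizer, lr_scheduler, and trainer from the YAML:
    #   python main.py fit --config experiments/pokemon.yaml
    LightningCLI()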
1,049 changes: 531 additions & 518 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/spherinator/data/__init__.py
@@ -7,6 +7,8 @@
 from .illustris_sdss_data_module import IllustrisSdssDataModule
 from .illustris_sdss_dataset import IllustrisSdssDataset
 from .illustris_sdss_dataset_with_metadata import IllustrisSdssDatasetWithMetadata
+from .images_data_module import ImagesDataModule
+from .images_dataset import ImagesDataset
 from .shapes_data_module import ShapesDataModule
 from .shapes_dataset import ShapesDataset
 from .spherinator_data_module import SpherinatorDataModule
@@ -18,6 +20,8 @@
     "IllustrisSdssDataModule",
     "IllustrisSdssDataset",
     "IllustrisSdssDatasetWithMetadata",
+    "ImagesDataModule",
+    "ImagesDataset",
     "ShapesDataModule",
     "ShapesDataset",
     "SpherinatorDataModule",
2 changes: 1 addition & 1 deletion src/spherinator/data/galaxy_zoo_data_module.py
@@ -65,7 +65,7 @@ def setup(self, stage: str):
         Args:
             stage (str): Defines for which stage the data is needed.
         """
-        if not stage in ["fit", "processing", "images", "thumbnail_images"]:
+        if stage not in ["fit", "processing", "images", "thumbnail_images"]:
             raise ValueError(f"Stage {stage} not supported.")
 
         if stage == "fit" and self.data_train is None:
2 changes: 1 addition & 1 deletion src/spherinator/data/illustris_sdss_data_module.py
@@ -93,7 +93,7 @@ def setup(self, stage: str):
         Args:
             stage (str): Defines for which stage the data is needed.
         """
-        if not stage in ["fit", "processing", "images", "thumbnail_images"]:
+        if stage not in ["fit", "processing", "images", "thumbnail_images"]:
             raise ValueError(f"Stage {stage} not supported.")
 
         if stage == "fit" and self.data_train is None:
85 changes: 85 additions & 0 deletions src/spherinator/data/images_data_module.py
@@ -0,0 +1,85 @@
import torch
import torchvision.transforms.v2 as transforms
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

from spherinator.data.images_dataset import ImagesDataset


class ImagesDataModule(LightningDataModule):
    """Defines access to the ImagesDataset."""

    def __init__(
        self,
        data_directory: str,
        extensions: list[str] = ["jpg"],
        shuffle: bool = True,
        image_size: int = 64,
        batch_size: int = 32,
        num_workers: int = 1,
    ):
        """Initializes the data loader
        Args:
            data_directory (str): The data directory.
            extensions (list[str], optional): The file extensions to consider. Defaults to ["jpg"].
            shuffle (bool, optional): Whether or not to shuffle when reading. Defaults to True.
            image_size (int, optional): The size of the images. Defaults to 64.
            batch_size (int, optional): The batch size for training. Defaults to 32.
            num_workers (int, optional): How many workers to use for loading. Defaults to 1.
        """
        super().__init__()

        self.data_directory = data_directory
        self.extensions = extensions
        self.shuffle = shuffle
        self.image_size = image_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.data_train = None
        self.dataloader_train = None

        self.transform_train = transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), antialias=True),
                transforms.Lambda(  # Min-max normalize to [0, 1]
                    lambda x: (x - torch.min(x)) / (torch.max(x) - torch.min(x))
                ),
            ]
        )
        self.transform_processing = self.transform_train
        self.transform_images = self.transform_train
        self.transform_thumbnail_images = transforms.Compose(
            [
                self.transform_train,
                transforms.Resize((100, 100), antialias=True),
            ]
        )

    def setup(self, stage: str):
        """Sets up the data set and data loaders.
        Args:
            stage (str): Defines for which stage the data is needed.
                For the moment just fitting is supported.
        """
        if stage != "fit":
            raise ValueError(f"Stage {stage} not supported.")

        if self.data_train is None:
            self.data_train = ImagesDataset(
                data_directory=self.data_directory,
                extensions=self.extensions,
                transform=self.transform_train,
            )
            self.dataloader_train = DataLoader(
                self.data_train,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                num_workers=self.num_workers,
            )

    def train_dataloader(self):
        """Gets the data loader for training."""
        return self.dataloader_train
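
Usage sketch for the new data module (not part of this commit; the path is a placeholder, and the batch shape assumes three-channel JPEGs):

from spherinator.data import ImagesDataModule

dm = ImagesDataModule(
    data_directory="/path/to/images",  # placeholder
    extensions=["jpg"],
    image_size=224,
    batch_size=32,
)
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print(batch.shape)  # expected: torch.Size([32, 3, 224, 224])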
67 changes: 67 additions & 0 deletions src/spherinator/data/images_dataset.py
@@ -0,0 +1,67 @@
""" Create dataset with all image files in a directory.
"""

import os
from pathlib import Path

import skimage.io as io
import torch
from torch.utils.data import Dataset


def get_all_filenames(data_directory: str, extensions: list[str]):
result = []
for dirpath, dirnames, filenames in os.walk(data_directory):
for filename in filenames:
if Path(filename).suffix[1:] in extensions:
result.append(os.path.join(dirpath, filename))
for dirname in dirnames:
result.extend(get_all_filenames(dirname, extensions))
return result


class ImagesDataset(Dataset):
"""Create dataset with all image files in a directory."""

def __init__(
self,
data_directory: str,
extensions: list[str] = ["jpg"],
transform=None,
):
"""Initializes the data set.
Args:
data_directory (str): The data directory.
transform (torchvision.transforms, optional): A single or a set of
transformations to modify the images. Defaults to None.
"""

self.transform = transform
self.filenames = sorted(get_all_filenames(data_directory, extensions))

def __len__(self) -> int:
"""Return the number of items in the dataset.
Returns:
int: Number of items in dataset.
"""
return len(self.filenames)

def __getitem__(self, index: int) -> torch.Tensor:
"""Retrieves the item/items with the given indices from the dataset.
Args:
index: The index of the item to retrieve.
Returns:
data: Data of the item/items with the given indices.
"""
# Swap axis 0 and 2 to bring the color channel to the front
data = io.imread(self.filenames[index])
data = data.swapaxes(0, 2)
data = torch.Tensor(data)
if self.transform:
data = self.transform(data)
return data
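
Sketch of the dataset used standalone (not part of this commit; the path and extensions are placeholders):

from spherinator.data import ImagesDataset

dataset = ImagesDataset("/path/to/images", extensions=["jpg", "png"])
print(len(dataset))  # number of files found by get_all_filenames
image = dataset[0]   # torch.Tensor with the color channel in front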
2 changes: 1 addition & 1 deletion src/spherinator/data/shapes_data_module.py
@@ -76,7 +76,7 @@ def setup(self, stage: str):
             stage (str): Defines for which stage the data is needed.
             For the moment just fitting is supported.
         """
-        if not stage in ["fit", "processing", "images", "thumbnail_images"]:
+        if stage not in ["fit", "processing", "images", "thumbnail_images"]:
             raise ValueError(f"Stage {stage} not supported.")
 
         if stage == "fit" and self.data_train is None:
10 changes: 10 additions & 0 deletions src/spherinator/models/__init__.py
@@ -13,9 +13,15 @@
 """
 
 from .convolutional_decoder import ConvolutionalDecoder
+from .convolutional_decoder_2 import ConvolutionalDecoder2
 from .convolutional_decoder_224 import ConvolutionalDecoder224
 from .convolutional_decoder_256 import ConvolutionalDecoder256
 from .convolutional_encoder import ConvolutionalEncoder
+from .convolutional_encoder_2 import ConvolutionalEncoder2
+from .rotational2_autoencoder import Rotational2Autoencoder
+from .rotational2_variational_autoencoder_power import (
+    Rotational2VariationalAutoencoderPower,
+)
 from .rotational_autoencoder import RotationalAutoencoder
 from .rotational_variational_autoencoder_power import (
     RotationalVariationalAutoencoderPower,
@@ -24,9 +30,13 @@
 __all__ = [
     "ConvolutionalDecoder",
+    "ConvolutionalDecoder2",
     "ConvolutionalDecoder224",
     "ConvolutionalDecoder256",
     "ConvolutionalEncoder",
+    "ConvolutionalEncoder2",
+    "Rotational2Autoencoder",
+    "Rotational2VariationalAutoencoderPower",
     "RotationalAutoencoder",
     "RotationalVariationalAutoencoderPower",
     "SpherinatorModule",
4 changes: 2 additions & 2 deletions src/spherinator/models/convolutional_decoder.py
@@ -5,10 +5,10 @@
 
 
 class ConvolutionalDecoder(pl.LightningModule):
-    def __init__(self, h_dim: int = 256):
+    def __init__(self, latent_dim: int):
         super().__init__()
 
-        self.fc = nn.Linear(h_dim, 256 * 4 * 4)
+        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
         self.deconv1 = nn.ConvTranspose2d(
             in_channels=256, out_channels=128, kernel_size=(4, 4), stride=2, padding=1
         )  # 8x8
47 changes: 47 additions & 0 deletions src/spherinator/models/convolutional_decoder_2.py
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn


class ConvolutionalDecoder2(nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()

        self.dec1 = nn.Sequential(
            nn.Linear(latent_dim, 1024 * 4 * 4),
            nn.Unflatten(1, (1024, 4, 4)),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )  # 1024 x 4 x 4
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )  # 512 x 8 x 8
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )  # 512 x 16 x 16
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )  # 256 x 32 x 32
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )  # 128 x 64 x 64
        self.dec6 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.BatchNorm2d(3),
        )  # 3 x 128 x 128

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dec1(x)
        x = self.dec2(x)
        x = self.dec3(x)
        x = self.dec4(x)
        x = self.dec5(x)
        x = self.dec6(x)
        return x
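
A quick shape check (not part of this commit) confirming the layer comments: each transposed convolution doubles the spatial size from 4x4 up to the 3x128x128 output.

import torch

decoder = ConvolutionalDecoder2(latent_dim=3)
out = decoder(torch.randn(8, 3))  # batch of 8 latent vectors
print(out.shape)  # torch.Size([8, 3, 128, 128])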
4 changes: 2 additions & 2 deletions src/spherinator/models/convolutional_decoder_224.py
@@ -5,15 +5,15 @@
 
 
 class ConvolutionalDecoder224(pl.LightningModule):
-    def __init__(self, h_dim: int = 256):
+    def __init__(self, latent_dim: int):
         """Convolutional decoder for 224x224 images.
         Input: h_dim
         Output: 3x224x224
         H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
         """
         super().__init__()
 
-        self.fc = nn.Linear(h_dim, 256 * 4 * 4)
+        self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
         self.deconv1 = nn.ConvTranspose2d(
             in_channels=256, out_channels=128, kernel_size=(3, 3), stride=2, padding=1
         )  # 7 = 6 - 2 + 2 + 1
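
The docstring formula can be checked directly for the first deconvolution above (sketch, not part of this commit): with H_in = 4, kernel 3, stride 2, padding 1, H_out = (4 - 1) * 2 - 2 + (3 - 1) + 0 + 1 = 7, matching the inline comment.

import torch
import torch.nn as nn

deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=2, padding=1)
out = deconv1(torch.randn(1, 256, 4, 4))
print(out.shape)  # torch.Size([1, 128, 7, 7])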