Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alternative rotational invariance and convolution neural networks #171

Merged
merged 25 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a1c1627
train with not rotated image and reconstruction loss as minimum over …
BerndDoser May 14, 2024
c7a15c0
improve cnn encoder/decoder architecture:
BerndDoser May 17, 2024
5348385
add ImageDataset
BerndDoser May 23, 2024
4a1433e
Merge remote-tracking branch 'origin/main' into rot-loss
BerndDoser May 23, 2024
b5752dd
Merge remote-tracking branch 'ain/main' into rot-loss
BerndDoser May 23, 2024
8e73df8
Add ImagesDataset and ImagesDataModule
BerndDoser May 23, 2024
430c480
replace PIL by skimage.io for better performance
BerndDoser May 23, 2024
cc27bf7
update plain autoencoder
BerndDoser May 24, 2024
7c2fb24
add pokemon config
BerndDoser May 24, 2024
741d758
Improved CNN architecture
BerndDoser May 24, 2024
d18dd01
Merge branch 'main' into rot-loss
BerndDoser May 27, 2024
693a300
rename h_dim and z_dim into latent_dim ...
BerndDoser May 29, 2024
601eaf7
normalize loss by brightness
BerndDoser May 29, 2024
715ed7a
remove sigmoid from last layer of decoder
BerndDoser Jun 4, 2024
bb1e48c
Merge branch 'main' into rot-loss
BerndDoser Jun 7, 2024
8582fc8
Merge branch 'main' into rot-loss
BerndDoser Jun 26, 2024
beee3ea
use Optional to support python 3.9
BerndDoser Jun 26, 2024
bfb5f60
Revert "train with not rotated image and reconstruction loss as minim…
BerndDoser Jun 27, 2024
a16f967
git ignore sarif files
BerndDoser Jun 27, 2024
ea158f8
Rotational2Autoencoder, which takes the minimal loss of all rotated i…
BerndDoser Jun 27, 2024
8a76c7c
switch off autocast for torch.no_grad
BerndDoser Jul 10, 2024
9de1353
torch dynamo export not supported for python 3.12, skip tests
BerndDoser Jul 10, 2024
8578d9a
fix flake8 warnings
BerndDoser Jul 22, 2024
a1c54da
disable autocast for torch.no_grad for all autoencoders
BerndDoser Jul 24, 2024
574960f
poetry update
BerndDoser Jul 24, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ __pycache__
.venv*/
*.ckpt
*.ncu-rep
*.sarif
config.yaml
dist/
docs/_build/
docs/html/
HiPSter/
lightning_logs/
wandb/
docs/_build/
docs/html/
66 changes: 66 additions & 0 deletions experiments/pokemon.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
seed_everything: 42

model:
class_path: spherinator.models.RotationalVariationalAutoencoderPower
init_args:
encoder:
class_path: spherinator.models.ConvolutionalEncoder
decoder:
class_path: spherinator.models.ConvolutionalDecoder
h_dim: 256
z_dim: 3
image_size: 224
rotations: 1
beta: 1.0e-3

data:
class_path: spherinator.data.ImagesDataModule
init_args:
data_directory: /local_data/doserbd/data/pokemon
extensions: ['jpg']
image_size: 224
batch_size: 32
shuffle: True
num_workers: 16

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 1.e-3

lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: min
factor: 0.1
patience: 5
cooldown: 5
min_lr: 1.e-6
monitor: train_loss
verbose: True

trainer:
max_epochs: -1
accelerator: gpu
devices: [3]
precision: 32
callbacks:
- class_path: spherinator.callbacks.LogReconstructionCallback
init_args:
num_samples: 6
# - class_path: lightning.pytorch.callbacks.ModelCheckpoint
# init_args:
# monitor: train_loss
# filename: "{epoch}-{train_loss:.2f}"
# save_top_k: 3
# mode: min
# every_n_epochs: 1
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
log_model: True
entity: ain-space
tags:
- rot-loss
- pokemon
1,049 changes: 531 additions & 518 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/spherinator/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .illustris_sdss_data_module import IllustrisSdssDataModule
from .illustris_sdss_dataset import IllustrisSdssDataset
from .illustris_sdss_dataset_with_metadata import IllustrisSdssDatasetWithMetadata
from .images_data_module import ImagesDataModule
from .images_dataset import ImagesDataset
from .shapes_data_module import ShapesDataModule
from .shapes_dataset import ShapesDataset
from .spherinator_data_module import SpherinatorDataModule
Expand All @@ -18,6 +20,8 @@
"IllustrisSdssDataModule",
"IllustrisSdssDataset",
"IllustrisSdssDatasetWithMetadata",
"ImagesDataModule",
"ImagesDataset",
"ShapesDataModule",
"ShapesDataset",
"SpherinatorDataModule",
Expand Down
2 changes: 1 addition & 1 deletion src/spherinator/data/galaxy_zoo_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def setup(self, stage: str):
Args:
stage (str): Defines for which stage the data is needed.
"""
if not stage in ["fit", "processing", "images", "thumbnail_images"]:
if stage not in ["fit", "processing", "images", "thumbnail_images"]:
raise ValueError(f"Stage {stage} not supported.")

if stage == "fit" and self.data_train is None:
Expand Down
2 changes: 1 addition & 1 deletion src/spherinator/data/illustris_sdss_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def setup(self, stage: str):
Args:
stage (str): Defines for which stage the data is needed.
"""
if not stage in ["fit", "processing", "images", "thumbnail_images"]:
if stage not in ["fit", "processing", "images", "thumbnail_images"]:
raise ValueError(f"Stage {stage} not supported.")

if stage == "fit" and self.data_train is None:
Expand Down
85 changes: 85 additions & 0 deletions src/spherinator/data/images_data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
import torchvision.transforms.v2 as transforms
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader

from spherinator.data.images_dataset import ImagesDataset


class ImagesDataModule(LightningDataModule):
"""Defines access to the ImagesDataset."""

def __init__(
self,
data_directory: str,
extensions: list[str] = ["jpg"],
shuffle: bool = True,
image_size: int = 64,
batch_size: int = 32,
num_workers: int = 1,
):
"""Initializes the data loader

Args:
data_directory (str): The data directory
shuffle (bool, optional): Wether or not to shuffle whe reading. Defaults to True.
image_size (int, optional): The size of the images. Defaults to 64.
batch_size (int, optional): The batch size for training. Defaults to 32.
num_workers (int, optional): How many worker to use for loading. Defaults to 1.
download (bool, optional): Wether or not to download the data. Defaults to False.
"""
super().__init__()

self.data_directory = data_directory
self.extensions = extensions
self.shuffle = shuffle
self.image_size = image_size
self.batch_size = batch_size
self.num_workers = num_workers

self.data_train = None
self.dataloader_train = None

self.transform_train = transforms.Compose(
[
transforms.Resize((self.image_size, self.image_size), antialias=True),
transforms.Lambda( # Normalize
lambda x: (x - torch.min(x)) / (torch.max(x) - torch.min(x))
),
]
)
self.transform_processing = self.transform_train
self.transform_images = self.transform_train
self.transform_thumbnail_images = transforms.Compose(
[
self.transform_train,
transforms.Resize((100, 100), antialias=True),
]
)

def setup(self, stage: str):
"""Sets up the data set and data loaders.

Args:
stage (str): Defines for which stage the data is needed.
For the moment just fitting is supported.
"""

if stage == "fit" and self.data_train is None:
self.data_train = ImagesDataset(
data_directory=self.data_directory,
extensions=self.extensions,
transform=self.transform_train,
)
self.dataloader_train = DataLoader(
self.data_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
)
else:
raise ValueError(f"Stage {stage} not supported.")

def train_dataloader(self):
"""Gets the data loader for training."""
return self.dataloader_train
67 changes: 67 additions & 0 deletions src/spherinator/data/images_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
""" Create dataset with all image files in a directory.
"""

import os
from pathlib import Path

import skimage.io as io
import torch
from torch.utils.data import Dataset


def get_all_filenames(data_directory: str, extensions: list[str]):
result = []
for dirpath, dirnames, filenames in os.walk(data_directory):
for filename in filenames:
if Path(filename).suffix[1:] in extensions:
result.append(os.path.join(dirpath, filename))
for dirname in dirnames:
result.extend(get_all_filenames(dirname, extensions))
return result


class ImagesDataset(Dataset):
"""Create dataset with all image files in a directory."""

def __init__(
self,
data_directory: str,
extensions: list[str] = ["jpg"],
transform=None,
):
"""Initializes the data set.

Args:
data_directory (str): The data directory.
transform (torchvision.transforms, optional): A single or a set of
transformations to modify the images. Defaults to None.
"""

self.transform = transform
self.filenames = sorted(get_all_filenames(data_directory, extensions))

def __len__(self) -> int:
"""Return the number of items in the dataset.

Returns:
int: Number of items in dataset.
"""
return len(self.filenames)

def __getitem__(self, index: int) -> torch.Tensor:
"""Retrieves the item/items with the given indices from the dataset.

Args:
index: The index of the item to retrieve.

Returns:
data: Data of the item/items with the given indices.

"""
# Swap axis 0 and 2 to bring the color channel to the front
data = io.imread(self.filenames[index])
data = data.swapaxes(0, 2)
data = torch.Tensor(data)
if self.transform:
data = self.transform(data)
return data
2 changes: 1 addition & 1 deletion src/spherinator/data/shapes_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def setup(self, stage: str):
stage (str): Defines for which stage the data is needed.
For the moment just fitting is supported.
"""
if not stage in ["fit", "processing", "images", "thumbnail_images"]:
if stage not in ["fit", "processing", "images", "thumbnail_images"]:
raise ValueError(f"Stage {stage} not supported.")

if stage == "fit" and self.data_train is None:
Expand Down
10 changes: 10 additions & 0 deletions src/spherinator/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
"""

from .convolutional_decoder import ConvolutionalDecoder
from .convolutional_decoder_2 import ConvolutionalDecoder2
from .convolutional_decoder_224 import ConvolutionalDecoder224
from .convolutional_decoder_256 import ConvolutionalDecoder256
from .convolutional_encoder import ConvolutionalEncoder
from .convolutional_encoder_2 import ConvolutionalEncoder2
from .rotational2_autoencoder import Rotational2Autoencoder
from .rotational2_variational_autoencoder_power import (
Rotational2VariationalAutoencoderPower,
)
from .rotational_autoencoder import RotationalAutoencoder
from .rotational_variational_autoencoder_power import (
RotationalVariationalAutoencoderPower,
Expand All @@ -24,9 +30,13 @@

__all__ = [
"ConvolutionalDecoder",
"ConvolutionalDecoder2",
"ConvolutionalDecoder224",
"ConvolutionalDecoder256",
"ConvolutionalEncoder",
"ConvolutionalEncoder2",
"Rotational2Autoencoder",
"Rotational2VariationalAutoencoderPower",
"RotationalAutoencoder",
"RotationalVariationalAutoencoderPower",
"SpherinatorModule",
Expand Down
4 changes: 2 additions & 2 deletions src/spherinator/models/convolutional_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@


class ConvolutionalDecoder(pl.LightningModule):
def __init__(self, h_dim: int = 256):
def __init__(self, latent_dim: int):
super().__init__()

self.fc = nn.Linear(h_dim, 256 * 4 * 4)
self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
self.deconv1 = nn.ConvTranspose2d(
in_channels=256, out_channels=128, kernel_size=(4, 4), stride=2, padding=1
) # 8x8
Expand Down
47 changes: 47 additions & 0 deletions src/spherinator/models/convolutional_decoder_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn


class ConvolutionalDecoder2(nn.Module):
def __init__(self, latent_dim: int):
super().__init__()

self.dec1 = nn.Sequential(
nn.Linear(latent_dim, 1024 * 4 * 4),
nn.Unflatten(1, (1024, 4, 4)),
nn.BatchNorm2d(1024),
nn.ReLU(),
) # 512 x 8 x 8
self.dec2 = nn.Sequential(
nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
) # 512 x 8 x 8
self.dec3 = nn.Sequential(
nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(),
) # 512 x 16 x 16
self.dec4 = nn.Sequential(
nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(),
) # 256 x 32 x 32
self.dec5 = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(),
) # 128 x 64 x 64
self.dec6 = nn.Sequential(
nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
nn.BatchNorm2d(3),
) # 3 x 128 x 128

def forward(self, x: torch.tensor) -> torch.tensor:
x = self.dec1(x)
x = self.dec2(x)
x = self.dec3(x)
x = self.dec4(x)
x = self.dec5(x)
x = self.dec6(x)
return x
4 changes: 2 additions & 2 deletions src/spherinator/models/convolutional_decoder_224.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@


class ConvolutionalDecoder224(pl.LightningModule):
def __init__(self, h_dim: int = 256):
def __init__(self, latent_dim: int):
"""Convolutional decoder for 224x224 images.
Input: h_dim
Output: 3x224x224
H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
"""
super().__init__()

self.fc = nn.Linear(h_dim, 256 * 4 * 4)
self.fc = nn.Linear(latent_dim, 256 * 4 * 4)
self.deconv1 = nn.ConvTranspose2d(
in_channels=256, out_channels=128, kernel_size=(3, 3), stride=2, padding=1
) # 7 = 6 - 2 + 2 + 1
Expand Down
Loading
Loading