Skip to content

Commit

Permalink
Merge pull request #33 from BerndDoser/devel
Browse files Browse the repository at this point in the history
Power spherical distribution
  • Loading branch information
BerndDoser authored Oct 20, 2023
2 parents 93b7af8 + dbb687f commit b4c1101
Show file tree
Hide file tree
Showing 39 changed files with 743 additions and 221 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__
*.ckpt
*.ncu-rep
HiPSter/
lightning_logs/
data/MNIST/
local/
wandb/
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[submodule "external/s-vae-pytorch"]
path = external/s-vae-pytorch
url = https://github.com/nicola-decao/s-vae-pytorch
[submodule "external/power_spherical"]
path = external/power_spherical
url = https://github.com/HITS-AIN/power_spherical.git
branch = jit
6 changes: 5 additions & 1 deletion .vscode/extensions.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
"ms-python.pylint",
"ms-toolsai.jupyter",
"streetsidesoftware.code-spell-checker",
"ms-python.isort"
"ms-python.isort",
"github.vscode-github-actions",
"eamodio.gitlens",
"nvidia.nsight-vscode-edition",
"ziruiwang.nvidia-monitor"
]
}
9 changes: 9 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: power spherical",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"args": "fit -c ${workspaceFolder}/experiments/shapes-power.yaml",
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Python: Current File",
"type": "python",
Expand Down
8 changes: 7 additions & 1 deletion data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,22 @@
2. `illustris_sdss_data_module`: The illustris data module for training.
3. `galaxy_zoo_dataset`: Access to galaxy zoo data.
4. `galaxy_zoo_data_module`: The galaxy zoo data module for training.
5. `shapes_dataset`: Access to shapes data.
6. `shapes_data_module`: The shapes data module for training.
"""

from .illustris_sdss_dataset import IllustrisSdssDataset
from .illustris_sdss_data_module import IllustrisSdssDataModule
from .galaxy_zoo_dataset import GalaxyZooDataset
from .galaxy_zoo_data_module import GalaxyZooDataModule
from .shapes_dataset import ShapesDataset
from .shapes_data_module import ShapesDataModule

__all__ = [
'IllustrisSdssDataset',
'IllustrisSdssDataModule',
'GalaxyZooDataset',
'GalaxyZooDataModule'
'GalaxyZooDataModule',
'ShapesDataset',
'ShapesDataModule'
]
4 changes: 3 additions & 1 deletion data/galaxy_zoo_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
"""

import os

import numpy
import skimage.io as io
import torch
from torch.utils.data import Dataset
import skimage.io as io


class GalaxyZooDataset(Dataset):
""" Provides access to galaxy zoo images.
Expand Down
6 changes: 3 additions & 3 deletions data/illustris_sdss_dataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
""" Provides access to the Illustris sdss images.
"""
import os
from typing import List

import os
import numpy

import torch
from torch.utils.data import Dataset
from astropy.io import fits
from torch.utils.data import Dataset


class IllustrisSdssDataset(Dataset):
""" Provides access to Illustris sdss like images.
Expand Down
6 changes: 2 additions & 4 deletions data/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import numpy
import copy

import torch
import torchvision.transforms.functional as TF
from PIL import Image
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF


# FIXME: implement this
class DielemanTransformation():
Expand Down
106 changes: 106 additions & 0 deletions data/shapes_data_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
""" Defines access to the ShapesDataset.
"""
from typing import List

import lightning.pytorch as pl
from torch.utils.data import DataLoader
from torchvision import transforms

from data.shapes_dataset import ShapesDataset

class ShapesDataModule(pl.LightningDataModule):
""" Defines access to the ShapesDataset.
"""
def __init__(self,
data_directory: str,
shuffle: bool = True,
image_size: int = 91,
batch_size: int = 32,
num_workers: int = 1):
""" Initializes the data loader
Args:
data_directories (List[str]): The data directory
shuffle (bool, optional): Wether or not to shuffle whe reading. Defaults to True.
batch_size (int, optional): The batch size for training. Defaults to 32.
num_workers (int, optional): How many worker to use for loading. Defaults to 1.
"""
super().__init__()

self.data_directory = data_directory
self.shuffle = shuffle
self.image_size = image_size
self.batch_size = batch_size
self.num_workers = num_workers

self.transform_train = transforms.Compose([
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize((0,0,0), (290,290,290)),
transforms.Resize((self.image_size, self.image_size), antialias=True)
])
self.transform_predict = self.transform_train
self.transform_val = self.transform_train

self.data_train = None
self.dataloader_train = None
self.data_predict = None
self.dataloader_predict = None
self.data_val = None
self.dataloader_val = None

def setup(self, stage: str):
""" Sets up the data set and data loaders.
Args:
stage (str): Defines for which stage the data is needed.
For the moment just fitting is supported.
"""
if stage == "fit":
self.data_train = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_train)

self.dataloader_train = DataLoader(self.data_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers)
if stage == "predict":
self.data_predict = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_predict)

self.dataloader_predict = DataLoader(self.data_predict,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)

if stage == "val":
self.data_val = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_val)

self.dataloader_val = DataLoader(self.data_val,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers)

def train_dataloader(self):
""" Gets the data loader for training.
Returns:
torch.utils.data.DataLoader: The dataloader instance to use for training.
"""
return self.dataloader_train

def predict_dataloader(self):
""" Gets the data loader for prediction.
Returns:
torch.utils.data.DataLoader: The dataloader instance to use for prediction.
"""
return self.dataloader_predict

def val_dataloader(self):
""" Gets the data loader for validation.
Returns:
torch.utils.data.DataLoader: The dataloader instance to use for validation.
"""
return self.dataloader_val
54 changes: 54 additions & 0 deletions data/shapes_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
""" Test images with four shapes in random rotations.
"""
from typing import List

import os
import numpy as np

import torch
from torch.utils.data import Dataset

class ShapesDataset(Dataset):
""" Test images with four shapes in random rotations.
"""
def __init__(self,
data_directory: str,
transform = None):
""" Initializes an Illustris sdss data set.
Args:
data_directory (str): The data directory.
transform (torchvision.transforms.Compose, optional): A single or a set of
transformations to modify the images. Defaults to None.
"""
self.data_directory = data_directory
self.transform = transform
self.images = np.empty((0,64,64), np.float32)
for file in os.listdir(data_directory):
self.images = np.append(self.images, np.load(os.path.join(data_directory, file)),
axis=0).astype(np.float32)

def __len__(self):
""" Return the number of items in the dataset.
Returns:
int: Number of items in dataset.
"""
return len(self.images)

def __getitem__(self, idx):
""" Retrieves the item/items with the given indices from the dataset.
Args:
idx (int or tensor): The index of the item to retrieve.
Returns:
dictionary: A dictionary mapping image, filename and id.
"""
if torch.is_tensor(idx):
idx = idx.tolist()
image = torch.Tensor(self.images[idx])
if self.transform:
image = self.transform(image)
sample = {'image': image, 'id': idx}
return sample
24 changes: 24 additions & 0 deletions devel/power-jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import sys

import torch

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(script_dir, '../'))

import models

CHECKPOINT_PATH = "/home/doserbd/ain-space/local/shapes-power/spherinator/w5z7hffp/checkpoints/epoch=60-train_loss=19.32.ckpt"

model = models.RotationalVariationalAutoencoderPower(z_dim=3)
model.load_state_dict(torch.load(CHECKPOINT_PATH)["state_dict"])
model.eval()

# Test the model with a dummy input
model(model.example_input_array)

# Convert the model to TorchScript
script = model.to_torchscript()

# Save the model
script.save("model.pt")
24 changes: 24 additions & 0 deletions devel/rot-vae-jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os
import sys

import torch

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(script_dir, '../'))

import models

CHECKPOINT_PATH = "/home/doserbd/ain-space/local/shapes-rot-vae/spherinator/qtxqbr25/checkpoints/epoch=115-step=14500.ckpt"

model = models.RotationalVariationalAutoencoder(z_dim=3, distribution="vmf")
model.load_state_dict(torch.load(CHECKPOINT_PATH)["state_dict"])
model.eval()

# Test the model with a dummy input
model(model.example_input_array)

# Convert the model to TorchScript
script = model.to_torchscript()

# Save the model
script.save("model.pt")
18 changes: 18 additions & 0 deletions devel/simple-jit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
import lightning.pytorch as pl

class SimpleModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)

def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
script = model.to_torchscript()

# save for use in production environment
torch.jit.save(script, "model.pt")
7 changes: 7 additions & 0 deletions devel/test-tensor-cores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import torch
import torch.nn

bsz , inf, outf = 256, 1024, 2048
tensor = torch.randn(bsz, inf).cuda().half()
layer = torch.nn.Linear(inf, outf).cuda().half()
layer(tensor)
39 changes: 0 additions & 39 deletions experiments/gz-svae.yaml

This file was deleted.

Loading

0 comments on commit b4c1101

Please sign in to comment.