-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from BerndDoser/devel
Power spherical distribution
- Loading branch information
Showing
39 changed files
with
743 additions
and
221 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
__pycache__ | ||
*.ckpt | ||
*.ncu-rep | ||
HiPSter/ | ||
lightning_logs/ | ||
data/MNIST/ | ||
local/ | ||
wandb/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
[submodule "external/s-vae-pytorch"] | ||
path = external/s-vae-pytorch | ||
url = https://github.com/nicola-decao/s-vae-pytorch | ||
[submodule "external/power_spherical"] | ||
path = external/power_spherical | ||
url = https://github.com/HITS-AIN/power_spherical.git | ||
branch = jit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
""" Defines access to the ShapesDataset. | ||
""" | ||
from typing import List | ||
|
||
import lightning.pytorch as pl | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms | ||
|
||
from data.shapes_dataset import ShapesDataset | ||
|
||
class ShapesDataModule(pl.LightningDataModule): | ||
""" Defines access to the ShapesDataset. | ||
""" | ||
def __init__(self, | ||
data_directory: str, | ||
shuffle: bool = True, | ||
image_size: int = 91, | ||
batch_size: int = 32, | ||
num_workers: int = 1): | ||
""" Initializes the data loader | ||
Args: | ||
data_directories (List[str]): The data directory | ||
shuffle (bool, optional): Wether or not to shuffle whe reading. Defaults to True. | ||
batch_size (int, optional): The batch size for training. Defaults to 32. | ||
num_workers (int, optional): How many worker to use for loading. Defaults to 1. | ||
""" | ||
super().__init__() | ||
|
||
self.data_directory = data_directory | ||
self.shuffle = shuffle | ||
self.image_size = image_size | ||
self.batch_size = batch_size | ||
self.num_workers = num_workers | ||
|
||
self.transform_train = transforms.Compose([ | ||
transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||
transforms.Normalize((0,0,0), (290,290,290)), | ||
transforms.Resize((self.image_size, self.image_size), antialias=True) | ||
]) | ||
self.transform_predict = self.transform_train | ||
self.transform_val = self.transform_train | ||
|
||
self.data_train = None | ||
self.dataloader_train = None | ||
self.data_predict = None | ||
self.dataloader_predict = None | ||
self.data_val = None | ||
self.dataloader_val = None | ||
|
||
def setup(self, stage: str): | ||
""" Sets up the data set and data loaders. | ||
Args: | ||
stage (str): Defines for which stage the data is needed. | ||
For the moment just fitting is supported. | ||
""" | ||
if stage == "fit": | ||
self.data_train = ShapesDataset(data_directory=self.data_directory, | ||
transform=self.transform_train) | ||
|
||
self.dataloader_train = DataLoader(self.data_train, | ||
batch_size=self.batch_size, | ||
shuffle=self.shuffle, | ||
num_workers=self.num_workers) | ||
if stage == "predict": | ||
self.data_predict = ShapesDataset(data_directory=self.data_directory, | ||
transform=self.transform_predict) | ||
|
||
self.dataloader_predict = DataLoader(self.data_predict, | ||
batch_size=self.batch_size, | ||
shuffle=False, | ||
num_workers=self.num_workers) | ||
|
||
if stage == "val": | ||
self.data_val = ShapesDataset(data_directory=self.data_directory, | ||
transform=self.transform_val) | ||
|
||
self.dataloader_val = DataLoader(self.data_val, | ||
batch_size=self.batch_size, | ||
shuffle=False, | ||
num_workers=self.num_workers) | ||
|
||
def train_dataloader(self): | ||
""" Gets the data loader for training. | ||
Returns: | ||
torch.utils.data.DataLoader: The dataloader instance to use for training. | ||
""" | ||
return self.dataloader_train | ||
|
||
def predict_dataloader(self): | ||
""" Gets the data loader for prediction. | ||
Returns: | ||
torch.utils.data.DataLoader: The dataloader instance to use for prediction. | ||
""" | ||
return self.dataloader_predict | ||
|
||
def val_dataloader(self): | ||
""" Gets the data loader for validation. | ||
Returns: | ||
torch.utils.data.DataLoader: The dataloader instance to use for validation. | ||
""" | ||
return self.dataloader_val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" Test images with four shapes in random rotations. | ||
""" | ||
from typing import List | ||
|
||
import os | ||
import numpy as np | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
class ShapesDataset(Dataset): | ||
""" Test images with four shapes in random rotations. | ||
""" | ||
def __init__(self, | ||
data_directory: str, | ||
transform = None): | ||
""" Initializes an Illustris sdss data set. | ||
Args: | ||
data_directory (str): The data directory. | ||
transform (torchvision.transforms.Compose, optional): A single or a set of | ||
transformations to modify the images. Defaults to None. | ||
""" | ||
self.data_directory = data_directory | ||
self.transform = transform | ||
self.images = np.empty((0,64,64), np.float32) | ||
for file in os.listdir(data_directory): | ||
self.images = np.append(self.images, np.load(os.path.join(data_directory, file)), | ||
axis=0).astype(np.float32) | ||
|
||
def __len__(self): | ||
""" Return the number of items in the dataset. | ||
Returns: | ||
int: Number of items in dataset. | ||
""" | ||
return len(self.images) | ||
|
||
def __getitem__(self, idx): | ||
""" Retrieves the item/items with the given indices from the dataset. | ||
Args: | ||
idx (int or tensor): The index of the item to retrieve. | ||
Returns: | ||
dictionary: A dictionary mapping image, filename and id. | ||
""" | ||
if torch.is_tensor(idx): | ||
idx = idx.tolist() | ||
image = torch.Tensor(self.images[idx]) | ||
if self.transform: | ||
image = self.transform(image) | ||
sample = {'image': image, 'id': idx} | ||
return sample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append(os.path.join(script_dir, '../')) | ||
|
||
import models | ||
|
||
CHECKPOINT_PATH = "/home/doserbd/ain-space/local/shapes-power/spherinator/w5z7hffp/checkpoints/epoch=60-train_loss=19.32.ckpt" | ||
|
||
model = models.RotationalVariationalAutoencoderPower(z_dim=3) | ||
model.load_state_dict(torch.load(CHECKPOINT_PATH)["state_dict"]) | ||
model.eval() | ||
|
||
# Test the model with a dummy input | ||
model(model.example_input_array) | ||
|
||
# Convert the model to TorchScript | ||
script = model.to_torchscript() | ||
|
||
# Save the model | ||
script.save("model.pt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
sys.path.append(os.path.join(script_dir, '../')) | ||
|
||
import models | ||
|
||
CHECKPOINT_PATH = "/home/doserbd/ain-space/local/shapes-rot-vae/spherinator/qtxqbr25/checkpoints/epoch=115-step=14500.ckpt" | ||
|
||
model = models.RotationalVariationalAutoencoder(z_dim=3, distribution="vmf") | ||
model.load_state_dict(torch.load(CHECKPOINT_PATH)["state_dict"]) | ||
model.eval() | ||
|
||
# Test the model with a dummy input | ||
model(model.example_input_array) | ||
|
||
# Convert the model to TorchScript | ||
script = model.to_torchscript() | ||
|
||
# Save the model | ||
script.save("model.pt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import torch | ||
import lightning.pytorch as pl | ||
|
||
class SimpleModel(pl.LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.l1 = torch.nn.Linear(in_features=64, out_features=4) | ||
|
||
def forward(self, x): | ||
return torch.relu(self.l1(x.view(x.size(0), -1))) | ||
|
||
|
||
# create the model | ||
model = SimpleModel() | ||
script = model.to_torchscript() | ||
|
||
# save for use in production environment | ||
torch.jit.save(script, "model.pt") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
import torch | ||
import torch.nn | ||
|
||
bsz , inf, outf = 256, 1024, 2048 | ||
tensor = torch.randn(bsz, inf).cuda().half() | ||
layer = torch.nn.Linear(inf, outf).cuda().half() | ||
layer(tensor) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.