diff --git a/models/rotational_variational_autoencoder.py b/models/rotational_variational_autoencoder.py
index 999d809..78d0704 100644
--- a/models/rotational_variational_autoencoder.py
+++ b/models/rotational_variational_autoencoder.py
@@ -1,5 +1,6 @@
 import os
 import sys
+import math
 import torch
 import torch.linalg
 
@@ -39,11 +40,20 @@ def __init__(self,
         """
         super().__init__()
         self.save_hyperparameters()
 
-        self.example_input_array = torch.randn(1, 3, 64, 64)
-        self.h_dim, self.z_dim, self.distribution = h_dim, z_dim, distribution
+        self.h_dim = h_dim
+        self.z_dim = z_dim
+        self.distribution = distribution
         self.image_size = image_size
-        self.rotations, self.beta, self.spherical_loss_weight = rotations, beta, spherical_loss_weight
+        self.rotations = rotations
+        self.beta = beta
+        self.spherical_loss_weight = spherical_loss_weight
+
+        self.crop_size = int(self.image_size * math.sqrt(2) / 2)
+        self.input_size = 64
+        self.total_input_size = self.input_size * self.input_size * 3
+
+        self.example_input_array = torch.randn(1, 3, self.input_size, self.input_size)
 
         self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(5,5), stride=2, padding=2)
         self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5,5), stride=2, padding=2)
@@ -71,6 +81,9 @@ def __init__(self,
         self.deconv4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(4,4), stride=2, padding=1)
         self.deconv5 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(5,5), stride=1, padding=2)
 
+    def get_input_size(self):
+        return self.input_size
+
     def encode(self, x):
         x = F.relu(self.conv1(x))
         x = F.relu(self.conv2(x))
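
The new `crop_size` formula encodes a geometric fact worth spelling out: after rotating a square image by an arbitrary angle, the largest axis-aligned square guaranteed to contain only valid (non-padded) pixels has side `image_size * sqrt(2) / 2`, i.e. `image_size / sqrt(2)`, with the worst case at 45 degrees. The sketch below illustrates a presumed rotate-crop-resize pipeline around `crop_size` and `input_size`; the torchvision calls, the rotation angle, and the dataset `image_size` of 128 are assumptions for illustration, not taken from this diff.

```python
# A minimal sketch (not part of the diff) of how crop_size could be used:
# rotate, then center-crop to the largest always-valid square, then resize
# to the fixed encoder input of 64x64.
import math

import torch
import torchvision.transforms.functional as TF

image_size = 128                                   # hypothetical dataset image size
crop_size = int(image_size * math.sqrt(2) / 2)     # = 90; largest safe square after any rotation
input_size = 64                                    # fixed encoder input, as in the diff

image = torch.randn(1, 3, image_size, image_size)  # dummy batch of one RGB image
rotated = TF.rotate(image, angle=30.0)             # arbitrary rotation; corners become padding
cropped = TF.center_crop(rotated, [crop_size, crop_size])   # drop the padded corners
resized = TF.resize(cropped, [input_size, input_size], antialias=True)

assert resized.shape == (1, 3, input_size, input_size)
```

With `image_size = 128` this gives `crop_size = 90`, and every pixel of the resized 64x64 tensor originates from the source image regardless of the rotation angle, which matches the model's fixed `example_input_array` shape of `(1, 3, 64, 64)`.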