Skip to content

Commit

Permalink
add input_size to vmf autoencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
BerndDoser committed Oct 30, 2023
1 parent a997aec commit 326966d
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions models/rotational_variational_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import math

import torch
import torch.linalg
Expand Down Expand Up @@ -39,11 +40,20 @@ def __init__(self,
"""
super().__init__()
self.save_hyperparameters()
self.example_input_array = torch.randn(1, 3, 64, 64)

self.h_dim, self.z_dim, self.distribution = h_dim, z_dim, distribution
self.h_dim = h_dim
self.z_dim = z_dim
self.distribution = distribution
self.image_size = image_size
self.rotations, self.beta, self.spherical_loss_weight = rotations, beta, spherical_loss_weight
self.rotations = rotations
self.beta = beta
self.spherical_loss_weight = spherical_loss_weight

self.crop_size = int(self.image_size * math.sqrt(2) / 2)
self.input_size = 64
self.total_input_size = self.input_size * self.input_size * 3

self.example_input_array = torch.randn(1, 3, self.input_size, self.input_size)

self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(5,5), stride=2, padding=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5,5), stride=2, padding=2)
Expand Down Expand Up @@ -71,6 +81,9 @@ def __init__(self,
self.deconv4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(4,4), stride=2, padding=1)
self.deconv5 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(5,5), stride=1, padding=2)

def get_input_size(self):
return self.input_size

def encode(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
Expand Down

0 comments on commit 326966d

Please sign in to comment.