Commit a997aec

adding model input size

kai-polsterer committed Oct 29, 2023
1 parent bc18db7 commit a997aec
Showing 4 changed files with 13 additions and 4 deletions.
hipster.py (3 changes: 1 addition & 2 deletions)
@@ -50,7 +50,6 @@ def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64,
         self.hierarchy = hierarchy
         self.crop_size = crop_size
         self.output_size = output_size
-        self.model_size = 128
         self.distortion_correction = distortion_correction

     def check_folders(self, base_folder):
@@ -270,7 +269,7 @@ def project_dataset(self, model, dataloader, rotation_steps):
         for r in range(rotation_steps):
             rot_images = functional.rotate(images, 360/rotation_steps*r, expand=False) # rotate
             crop_images = functional.center_crop(rot_images, [self.crop_size, self.crop_size]) # crop
-            scaled_images = functional.resize(crop_images, [self.model_size, self.model_size], antialias=False) # scale
+            scaled_images = functional.resize(crop_images, [model.get_input_size(), model.get_input_size()], antialias=False) # scale
             with torch.no_grad():
                 coordinates = model.project(scaled_images)
                 reconstruction = model.reconstruct(coordinates)
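For context, the preprocessing chain in project_dataset can be read as a small standalone function. This is only a sketch, not code from the repository: the preprocess name, its defaults, and the single-step signature are invented here, and it assumes a model object exposing the new get_input_size() method.

from torchvision.transforms import functional

def preprocess(images, model, crop_size=64, rotation_steps=36, step=0):
    # rotate by one of rotation_steps evenly spaced angles
    rot = functional.rotate(images, 360 / rotation_steps * step, expand=False)
    # cut out the central crop_size x crop_size region
    crop = functional.center_crop(rot, [crop_size, crop_size])
    # resize to whatever resolution the model expects, queried from the
    # model itself instead of the previously hard-coded 128
    size = model.get_input_size()
    return functional.resize(crop, [size, size], antialias=False)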
models/rotational_autoencoder.py (7 changes: 5 additions & 2 deletions)
@@ -51,6 +51,9 @@ def __init__(self,
         self.deconv6 = nn.ConvTranspose2d(in_channels=16, out_channels=3,
                                           kernel_size=(2,2), stride=1, padding=0) #128x128

+    def get_input_size(self):
+        return self.input_size
+
     def encode(self, x):
         x = F.relu(self.conv0(x))
         x = self.pool0(x)
@@ -99,8 +102,8 @@ def training_step(self, train_batch, _batch_idx):
         losses = torch.zeros(images.shape[0], self.rotations)
         for i in range(self.rotations):
             rotate = functional.rotate(images, 360.0/self.rotations*i, expand=False) # rotate
-            crop = functional.center_crop(rotate, [256,256]) # crop
-            scaled = functional.resize(crop, [128,128], antialias=False) # scale
+            crop = functional.center_crop(rotate, [256,256]) # crop #TODO config?
+            scaled = functional.resize(crop, [self.input_size, self.input_size], antialias=False) # scale
             reconstruction, coordinates = self.forward(scaled)
             losses[:,i] = self.spherical_loss(scaled, reconstruction, coordinates)
         loss = torch.mean(torch.min(losses, dim=1)[0])
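The training_step above evaluates the reconstruction loss for every rotated copy of an image and keeps only the smallest value per image, so the autoencoder is not penalized for settling on one canonical orientation. A minimal sketch of that reduction, with random dummy values standing in for spherical_loss:

import torch

losses = torch.rand(8, 36)           # (batch, rotations), dummy loss values
best = torch.min(losses, dim=1)[0]   # best (smallest) rotation per image
loss = torch.mean(best)              # average the per-image minima over the batch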
models/rotational_variational_autoencoder_power.py (3 changes: 3 additions & 0 deletions)
@@ -83,6 +83,9 @@ def __init__(self,
         self.deconv6 = nn.ConvTranspose2d(in_channels=16, out_channels=3,
                                           kernel_size=(2,2), stride=1, padding=0) #128x128

+    def get_input_size(self):
+        return self.input_size
+
     def encode(self, x):
         x = F.relu(self.conv0(x))
         x = self.pool0(x)
models/spherinator_module.py (4 changes: 4 additions & 0 deletions)
@@ -7,6 +7,10 @@ class SpherinatorModule(ABC, pl.LightningModule):
     """
     Abstract base class for all spherinator modules to ensure that all methods for hipster are implemented.
     """
+    @abstractmethod
+    def get_input_size(self):
+        """Returns the size of the images the model takes as input and generates as output.
+        """

     @abstractmethod
     def project(self, images):
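A hypothetical minimal subclass (not part of this commit) illustrates the new contract: every concrete model must report the image size it consumes, so hipster.py can resize batches via model.get_input_size(). The DummyModel name, the import path, and the toy project body are invented for this sketch; any further abstract methods hidden in the folded part of the diff (hipster.py also calls reconstruct) would need implementing as well.

import torch

from models.spherinator_module import SpherinatorModule

class DummyModel(SpherinatorModule):
    def __init__(self, input_size=128):
        super().__init__()
        self.input_size = input_size

    def get_input_size(self):
        # the size hipster.py will resize inputs to
        return self.input_size

    def project(self, images):
        # toy projection: map every image to the same 3D coordinate
        return torch.zeros(images.shape[0], 3)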
