diff --git a/hipster.py b/hipster.py
index 96d6f81..31836e7 100755
--- a/hipster.py
+++ b/hipster.py
@@ -50,7 +50,6 @@ def __init__(self, output_folder, title, max_order=3, hierarchy=1, crop_size=64,
         self.hierarchy = hierarchy
         self.crop_size = crop_size
         self.output_size = output_size
-        self.model_size = 128
         self.distortion_correction = distortion_correction
 
     def check_folders(self, base_folder):
@@ -270,7 +269,7 @@ def project_dataset(self, model, dataloader, rotation_steps):
             for r in range(rotation_steps):
                 rot_images = functional.rotate(images, 360/rotation_steps*r, expand=False) # rotate
                 crop_images = functional.center_crop(rot_images, [self.crop_size, self.crop_size]) # crop
-                scaled_images = functional.resize(crop_images, [self.model_size, self.model_size], antialias=False) # scale
+                scaled_images = functional.resize(crop_images, [model.get_input_size(), model.get_input_size()], antialias=False) # scale
                 with torch.no_grad():
                     coordinates = model.project(scaled_images)
                     reconstruction = model.reconstruct(coordinates)
diff --git a/models/rotational_autoencoder.py b/models/rotational_autoencoder.py
index 7b6fb5c..30e96c8 100644
--- a/models/rotational_autoencoder.py
+++ b/models/rotational_autoencoder.py
@@ -51,6 +51,9 @@ def __init__(self,
         self.deconv6 = nn.ConvTranspose2d(in_channels=16, out_channels=3,
                                           kernel_size=(2,2), stride=1, padding=0) #128x128
 
+    def get_input_size(self):
+        return self.input_size
+
     def encode(self, x):
         x = F.relu(self.conv0(x))
         x = self.pool0(x)
@@ -99,8 +102,8 @@ def training_step(self, train_batch, _batch_idx):
         losses = torch.zeros(images.shape[0], self.rotations)
         for i in range(self.rotations):
             rotate = functional.rotate(images, 360.0/self.rotations*i, expand=False) # rotate
-            crop = functional.center_crop(rotate, [256,256]) # crop
-            scaled = functional.resize(crop, [128,128], antialias=False) # scale
+            crop = functional.center_crop(rotate, [256,256]) # crop #TODO config?
+            scaled = functional.resize(crop, [self.input_size, self.input_size], antialias=False) # scale
             reconstruction, coordinates = self.forward(scaled)
             losses[:,i] = self.spherical_loss(scaled, reconstruction, coordinates)
         loss = torch.mean(torch.min(losses, dim=1)[0])
diff --git a/models/rotational_variational_autoencoder_power.py b/models/rotational_variational_autoencoder_power.py
index e108e38..2dc9303 100644
--- a/models/rotational_variational_autoencoder_power.py
+++ b/models/rotational_variational_autoencoder_power.py
@@ -83,6 +83,9 @@ def __init__(self,
         self.deconv6 = nn.ConvTranspose2d(in_channels=16, out_channels=3,
                                           kernel_size=(2,2), stride=1, padding=0) #128x128
 
+    def get_input_size(self):
+        return self.input_size
+
     def encode(self, x):
         x = F.relu(self.conv0(x))
         x = self.pool0(x)
diff --git a/models/spherinator_module.py b/models/spherinator_module.py
index fa48bc9..6342d52 100644
--- a/models/spherinator_module.py
+++ b/models/spherinator_module.py
@@ -7,6 +7,10 @@ class SpherinatorModule(ABC, pl.LightningModule):
     """
     Abstract base class for all spherinator modules to ensure that all methods for hipster are implemented.
     """
+    @abstractmethod
+    def get_input_size(self):
+        """Returns the size of the images the model takes as input and generates as output.
+        """
+
     @abstractmethod
     def project(self, images):
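For context on the pattern this patch introduces: `SpherinatorModule` now declares `get_input_size` as an abstract method, so every concrete model must report its own expected image size, which is what lets `hipster.py` drop the hardcoded `self.model_size = 128`. Below is a minimal, self-contained sketch of that contract; `ToyModel` and its `input_size` default are hypothetical stand-ins, and the real base class also inherits from `pl.LightningModule`, omitted here for brevity.

```python
from abc import ABC, abstractmethod

class SpherinatorModule(ABC):
    """Simplified stand-in; the real class also inherits pl.LightningModule."""

    @abstractmethod
    def get_input_size(self):
        """Return the edge length of the square images the model consumes."""

class ToyModel(SpherinatorModule):  # hypothetical example model
    def __init__(self, input_size=128):
        self.input_size = input_size  # set from config, as in the real models

    def get_input_size(self):
        return self.input_size

model = ToyModel()
size = model.get_input_size()  # 128
# HiPSter can now size its resize call from the model itself, e.g.:
# functional.resize(images, [size, size], antialias=False)
```

A subclass that forgets to implement `get_input_size` fails immediately at instantiation with a `TypeError`, rather than failing later inside HiPSter's projection loop.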