Skip to content

Commit

Permalink
Revert "disable autocast for torch.no_grad for all autoencoders"
Browse files Browse the repository at this point in the history
This reverts commit a1c54da.
  • Loading branch information
BerndDoser committed Aug 13, 2024
1 parent 7d31cf0 commit 518de71
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 82 deletions.
55 changes: 26 additions & 29 deletions src/spherinator/models/rotational2_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,40 +68,37 @@ def forward(self, x):
return recon

def training_step(self, batch, batch_idx):
    """Run one training step: reconstruct the batch and take, per image, the
    minimum reconstruction loss over all tested rotations.

    The image preprocessing (crop/resize/rotate) is kept out of autocast so it
    runs in full precision; only the model forward pass and loss run under
    autocast.  # NOTE(review): assumes CUDA training — confirm for CPU runs.

    Args:
        batch: Input image tensor.  # assumes (N, C, H, W) — TODO confirm
        batch_idx: Index of the batch (unused, required by Lightning).

    Returns:
        Scalar training loss (mean over the batch).
    """
    # Disable autocast for the gradient-free preprocessing of the unrotated view.
    with torch.autocast("cuda", enabled=False):
        with torch.no_grad():
            crop = functional.center_crop(batch, [self.crop_size, self.crop_size])
            scaled = functional.resize(
                crop, [self.input_size, self.input_size], antialias=True
            )

        # Re-enable autocast for the model forward pass and loss computation.
        with torch.autocast("cuda", enabled=True):
            recon = self.forward(scaled)
            loss = self.reconstruction_loss(scaled, recon)

            # Compare the reconstruction against every additional rotation of the
            # input and keep, element-wise, the smallest loss.
            for i in range(1, self.rotations):
                with torch.no_grad():
                    rotate = functional.rotate(
                        batch, 360.0 / self.rotations * i, expand=False
                    )
                    crop = functional.center_crop(
                        rotate, [self.crop_size, self.crop_size]
                    )
                    scaled = functional.resize(
                        crop, [self.input_size, self.input_size], antialias=True
                    )

                loss = torch.min(loss, self.reconstruction_loss(scaled, recon))

            # divide by the brightness of the image
            if self.norm_brightness:
                loss = loss / torch.sum(scaled, (1, 2, 3)) * self.total_input_size

            loss = loss.mean()

    self.log("train_loss", loss, prog_bar=True)
    self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
    return loss

def configure_optimizers(self):
"""Default Adam optimizer if missing from the configuration file."""
Expand Down
61 changes: 28 additions & 33 deletions src/spherinator/models/rotational2_variational_autoencoder_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,45 +95,40 @@ def forward(self, x):
return (z_location, z_scale), (q_z, p_z), z, recon

def training_step(self, batch, batch_idx):
    """Run one variational training step: reconstruction loss minimized over
    rotations, plus the beta-weighted KL divergence between posterior and prior.

    Preprocessing (crop/resize/rotate) is excluded from autocast; the model
    forward pass, losses, and KL term run under autocast.
    # NOTE(review): assumes CUDA training — confirm for CPU runs.

    Args:
        batch: Input image tensor.  # assumes (N, C, H, W) — TODO confirm
        batch_idx: Index of the batch (unused, required by Lightning).

    Returns:
        Scalar training loss: mean of (reconstruction loss + beta * KL).
    """
    # Disable autocast for the gradient-free preprocessing of the unrotated view.
    with torch.autocast("cuda", enabled=False):
        with torch.no_grad():
            crop = functional.center_crop(batch, [self.crop_size, self.crop_size])
            scaled = functional.resize(
                crop, [self.input_size, self.input_size], antialias=True
            )

        # Re-enable autocast for the model forward pass and loss computation.
        with torch.autocast("cuda", enabled=True):
            (z_location, z_scale), (q_z, p_z), _, recon = self.forward(scaled)
            loss_recon = self.reconstruction_loss(scaled, recon)

            # Compare the reconstruction against every additional rotation of
            # the input and keep, element-wise, the smallest loss.
            for i in range(1, self.rotations):
                with torch.no_grad():
                    rotate = functional.rotate(
                        batch, 360.0 / self.rotations * i, expand=False
                    )
                    crop = functional.center_crop(
                        rotate, [self.crop_size, self.crop_size]
                    )
                    scaled = functional.resize(
                        crop, [self.input_size, self.input_size], antialias=True
                    )

                loss_recon = torch.min(
                    loss_recon, self.reconstruction_loss(scaled, recon)
                )

            # KL divergence between approximate posterior q(z|x) and prior p(z),
            # weighted by beta (beta-VAE style).
            loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z) * self.beta
            loss = (loss_recon + loss_KL).mean()
            loss_recon = loss_recon.mean()
            loss_KL = loss_KL.mean()

    self.log("train_loss", loss, prog_bar=True)
    self.log("loss_recon", loss_recon, prog_bar=True)
    self.log("loss_KL", loss_KL)
    self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
    self.log("mean(z_location)", torch.mean(z_location))
    self.log("mean(z_scale)", torch.mean(z_scale))
    return loss

def configure_optimizers(self):
"""Default Adam optimizer if missing from the configuration file."""
Expand Down
36 changes: 16 additions & 20 deletions src/spherinator/models/rotational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,22 @@ def forward(self, x):
return recon

def training_step(self, batch, batch_idx):
    """Run one training step using the rotation that reconstructs best.

    ``find_best_rotation`` (defined on this class, outside this hunk) selects
    the best-scaled input outside autocast; the forward pass and loss run
    under autocast.  # NOTE(review): assumes CUDA training — confirm for CPU.

    Args:
        batch: Input image tensor.  # assumes (N, C, H, W) — TODO confirm
        batch_idx: Index of the batch (unused, required by Lightning).

    Returns:
        Scalar training loss (mean over the batch).
    """
    # Disable autocast while searching for the best rotation in full precision.
    with torch.autocast("cuda", enabled=False):
        best_scaled_image, _, _, _ = self.find_best_rotation(batch)

        # Re-enable autocast for the model forward pass and loss computation.
        with torch.autocast("cuda", enabled=True):
            recon = self.forward(best_scaled_image)
            loss = self.reconstruction_loss(best_scaled_image, recon)

            # divide by the brightness of the image
            if self.norm_brightness:
                loss = (
                    loss
                    / torch.sum(best_scaled_image, (1, 2, 3))
                    * self.total_input_size
                )

            loss = loss.mean()

    self.log("train_loss", loss, prog_bar=True)
    self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
    return loss

def configure_optimizers(self):
"""Default Adam optimizer if missing from the configuration file."""
Expand Down

0 comments on commit 518de71

Please sign in to comment.