Revert "switch off autocast for torch.no_grad"
This reverts commit 8a76c7c.
BerndDoser committed Aug 13, 2024
1 parent 518de71 commit bb9d3a3
Showing 1 changed file with 16 additions and 21 deletions.
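Background for the revert: the reverted commit (8a76c7c) wrapped the rotation search in torch.autocast("cuda", enabled=False) and re-enabled autocast for the rest of training_step. Its title points at the underlying PyTorch behaviour that torch.no_grad() only suspends gradient tracking and does not switch autocast off. A minimal, self-contained sketch of that interaction, assuming a CUDA device and using throwaway tensors rather than the project's model:

    import torch

    a = torch.randn(8, 8, device="cuda")
    b = torch.randn(8, 8, device="cuda")

    with torch.autocast("cuda"):  # e.g. what a mixed-precision training loop enables
        with torch.no_grad():
            # no_grad alone: autocast is still active, so the matmul runs in half precision
            print((a @ b).dtype)  # torch.float16
        with torch.autocast("cuda", enabled=False), torch.no_grad():
            # autocast explicitly disabled: the matmul stays in full precision
            print((a @ b).dtype)  # torch.float32

After this revert, training_step contains no per-step autocast handling; whether the step runs in mixed precision is again governed entirely by the surrounding training setup (for PyTorch Lightning, typically the Trainer's precision option).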
37 changes: 16 additions & 21 deletions src/spherinator/models/rotational_variational_autoencoder_power.py
@@ -94,27 +94,22 @@ def forward(self, x):
         return (z_location, z_scale), (q_z, p_z), z, recon

     def training_step(self, batch, batch_idx):
-        with torch.autocast("cuda", enabled=False):
-            best_scaled_image, _, _, _ = self.find_best_rotation(batch)
-
-        with torch.autocast("cuda", enabled=True):
-            (z_location, z_scale), (q_z, p_z), _, recon = self.forward(
-                best_scaled_image
-            )
-
-            loss_recon = self.reconstruction_loss(best_scaled_image, recon)
-            loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z) * self.beta
-            loss = (loss_recon + loss_KL).mean()
-            loss_recon = loss_recon.mean()
-            loss_KL = loss_KL.mean()
-
-            self.log("train_loss", loss, prog_bar=True)
-            self.log("loss_recon", loss_recon, prog_bar=True)
-            self.log("loss_KL", loss_KL)
-            self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
-            self.log("mean(z_location)", torch.mean(z_location))
-            self.log("mean(z_scale)", torch.mean(z_scale))
-            return loss
+        best_scaled_image, _, _, _ = self.find_best_rotation(batch)
+        (z_location, z_scale), (q_z, p_z), _, recon = self.forward(best_scaled_image)
+
+        loss_recon = self.reconstruction_loss(best_scaled_image, recon)
+        loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z) * self.beta
+        loss = (loss_recon + loss_KL).mean()
+        loss_recon = loss_recon.mean()
+        loss_KL = loss_KL.mean()
+
+        self.log("train_loss", loss, prog_bar=True)
+        self.log("loss_recon", loss_recon, prog_bar=True)
+        self.log("loss_KL", loss_KL)
+        self.log("learning_rate", self.optimizers().param_groups[0]["lr"])
+        self.log("mean(z_location)", torch.mean(z_location))
+        self.log("mean(z_scale)", torch.mean(z_scale))
+        return loss

     def configure_optimizers(self):
         """Default Adam optimizer if missing from the configuration file."""
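A side note on the restored loss computation: torch.distributions.kl.kl_divergence(q_z, p_z) yields one KL value per batch element, which is scaled by self.beta and averaged together with the per-sample reconstruction loss. A stand-alone illustration, using plain Normal distributions in place of the model's actual q_z and p_z (which, going by the file name, are power-spherical/hyperspherical) and a made-up beta:

    import torch

    q_z = torch.distributions.Normal(torch.zeros(4), 2.0 * torch.ones(4))  # stand-in posterior
    p_z = torch.distributions.Normal(torch.zeros(4), torch.ones(4))        # stand-in prior
    beta = 1.0e-3                                                          # placeholder weight

    loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z) * beta
    print(loss_KL.shape)   # per-sample values: torch.Size([4])
    print(loss_KL.mean())  # scalar that gets logged as "loss_KL"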
