-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #42 from BerndDoser/callback
Monitoring reconstructed images during training
- Loading branch information
Showing
18 changed files
with
240 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
PyTorch Lightning callbacks | ||
""" | ||
|
||
from .log_reconstruction_callback import LogReconstructionCallback | ||
|
||
__all__ = [ | ||
'LogReconstructionCallback', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import matplotlib.pyplot as plt | ||
import torch | ||
import torchvision.transforms.functional as functional | ||
from lightning.pytorch.callbacks import Callback | ||
|
||
|
||
class LogReconstructionCallback(Callback): | ||
def __init__(self, num_samples=4): | ||
super().__init__() | ||
self.num_samples = num_samples | ||
|
||
def on_train_epoch_end(self, trainer, pl_module): | ||
|
||
# Return if no wandb logger is used | ||
if trainer.logger is None or trainer.logger.__class__.__name__ not in ["WandbLogger", "MyLogger"]: | ||
return | ||
|
||
# Generate some random samples from the validation set | ||
samples = next(iter(trainer.train_dataloader))['image'] | ||
samples = samples[:self.num_samples] | ||
samples = samples.to(pl_module.device) | ||
|
||
# Generate reconstructions of the samples using the model | ||
with torch.no_grad(): | ||
batch_size = samples.shape[0] | ||
losses = torch.zeros(batch_size, pl_module.rotations) | ||
images = torch.zeros((batch_size, 3, pl_module.input_size, pl_module.input_size, pl_module.rotations)) | ||
recons = torch.zeros((batch_size, 3, pl_module.input_size, pl_module.input_size, pl_module.rotations)) | ||
coords = torch.zeros((batch_size, 3, pl_module.rotations)) | ||
for r in range(pl_module.rotations): | ||
rotate = functional.rotate(samples, 360.0 / pl_module.rotations * r, expand=False) | ||
crop = functional.center_crop(rotate, [pl_module.crop_size, pl_module.crop_size]) | ||
scaled = functional.resize(crop, [pl_module.input_size, pl_module.input_size], antialias=False) | ||
|
||
(z_mean, _), (_ ,_), _, recon = pl_module(scaled) | ||
|
||
losses[:,r] = pl_module.reconstruction_loss(scaled, recon) | ||
images[:,:,:,:,r] = scaled | ||
recons[:,:,:,:,r] = recon | ||
coords[:,:,r] = z_mean | ||
|
||
min_idx = torch.min(losses, dim=1)[1] | ||
|
||
# Plot the original samples and their reconstructions side by side | ||
fig, axs = plt.subplots(self.num_samples, 2, figsize=(6, 2*self.num_samples)) | ||
for i in range(self.num_samples): | ||
axs[i, 0].imshow(images[i,:,:,:,min_idx[i]].cpu().detach().numpy().T) | ||
axs[i, 0].set_title("Original") | ||
axs[i, 0].axis("off") | ||
axs[i, 1].imshow(recons[i,:,:,:,min_idx[i]].cpu().detach().numpy().T) | ||
axs[i, 1].set_title("Reconstruction") | ||
axs[i, 1].axis("off") | ||
plt.tight_layout() | ||
|
||
# Log the figure at W&B | ||
trainer.logger.log_image(key="Reconstructions", images=[fig]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
trainer: | ||
logger: | ||
class_path: lightning.pytorch.loggers.WandbLogger | ||
init_args: | ||
project: spherinator | ||
log_model: all | ||
entity: ain-space | ||
tags: | ||
- log-reconstructions | ||
callbacks: | ||
- class_path: callbacks.LogReconstructionCallback | ||
init_args: | ||
num_samples: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
astropy==5.3.4 | ||
lightning==2.1.0 | ||
matplotlib==3.7.2 | ||
scikit-image==0.22.0 | ||
scipy==1.11.3 | ||
torch==2.1.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import matplotlib.pyplot as plt | ||
from lightning.pytorch.loggers import Logger | ||
from lightning.pytorch.trainer import Trainer | ||
from lightning.pytorch.utilities import rank_zero_only | ||
|
||
from callbacks import LogReconstructionCallback | ||
from data import ShapesDataModule | ||
from models import RotationalVariationalAutoencoderPower | ||
|
||
|
||
class MyLogger(Logger): | ||
|
||
@property | ||
def name(self): | ||
return "MyLogger" | ||
|
||
@property | ||
def version(self): | ||
return "0.1" | ||
|
||
@rank_zero_only | ||
def log_hyperparams(self, params): | ||
pass | ||
|
||
@rank_zero_only | ||
def log_metrics(self, metrics, step): | ||
pass | ||
|
||
@rank_zero_only | ||
def save(self): | ||
pass | ||
|
||
@rank_zero_only | ||
def finalize(self, status): | ||
pass | ||
|
||
def __init__(self): | ||
self.calls = 0 | ||
self.logged_items = [] | ||
|
||
def log_image(self, key, images): | ||
self.calls += 1 | ||
self.logged_items.append((key, images)) | ||
|
||
|
||
def test_on_train_epoch_end(): | ||
|
||
# Set up the model and dataloader | ||
z_dim = 3 | ||
model = RotationalVariationalAutoencoderPower(z_dim=z_dim) | ||
|
||
datamodule = ShapesDataModule("tests/data/shapes", num_workers=1, batch_size=12) | ||
datamodule.setup("fit") | ||
# data_loader = data_module.train_dataloader() | ||
|
||
logger = MyLogger() | ||
|
||
trainer = Trainer(max_epochs=1, logger=logger, overfit_batches = 2) | ||
trainer.fit(model, datamodule=datamodule) | ||
|
||
# Set up the callback | ||
num_samples = 2 | ||
callback = LogReconstructionCallback(num_samples=num_samples) | ||
|
||
# Call the callback | ||
callback.on_train_epoch_end(trainer=trainer, pl_module=model) | ||
|
||
logger.finalize("success") | ||
|
||
# Check that the figure was logged | ||
assert logger.calls == 1 | ||
assert "Reconstructions" in logger.logged_items[0][0] | ||
assert isinstance(logger.logged_items[0][1][0], plt.Figure) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters