Skip to content

Commit

Permalink
Merge pull request #42 from BerndDoser/callback
Browse files Browse the repository at this point in the history
Monitoring reconstructed images during training
  • Loading branch information
BerndDoser authored Nov 9, 2023
2 parents 91c973c + 3e0e9d1 commit 7f8d8e6
Show file tree
Hide file tree
Showing 18 changed files with 240 additions and 28 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,12 @@ The following command generates a HiPS representation and a catalog showing the
```

Call `./hipster.py --help` for more information.


## Visualize reconstructed images during training

Append the config file [wandb-log-reconstructions.yaml](experiments/wandb-log-reconstructions.yaml) to the training command to visualize the reconstructed images in W&B during training.

```bash
python main.py fit -c experiments/illustris.yaml -c experiments/wandb-log-reconstructions.yaml
```
9 changes: 9 additions & 0 deletions callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
PyTorch Lightning callbacks
"""

from .log_reconstruction_callback import LogReconstructionCallback

# Names exported by `from callbacks import *`.
__all__ = [
'LogReconstructionCallback',
]
56 changes: 56 additions & 0 deletions callbacks/log_reconstruction_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as functional
from lightning.pytorch.callbacks import Callback


class LogReconstructionCallback(Callback):
    """Log original images next to their reconstructions after each training epoch.

    At the end of every training epoch the first batch of the training
    dataloader is taken, each sample is rotated ``pl_module.rotations`` times,
    and for every sample the rotation with the smallest reconstruction loss is
    plotted beside its reconstruction. The resulting matplotlib figure is
    passed to ``trainer.logger.log_image``.

    Args:
        num_samples (int, optional): Maximum number of samples to plot.
            Defaults to 4.
    """

    def __init__(self, num_samples=4):
        super().__init__()
        self.num_samples = num_samples

    def on_train_epoch_end(self, trainer, pl_module):
        """Build the comparison figure and hand it to the logger.

        Args:
            trainer: The Lightning trainer; its ``logger`` and
                ``train_dataloader`` are used.
            pl_module: The model. It must provide ``rotations``, ``crop_size``,
                ``input_size``, ``device``, ``reconstruction_loss`` and a
                forward pass whose fourth return value is the reconstruction.
        """
        # Return if no wandb logger is used
        if trainer.logger is None or trainer.logger.__class__.__name__ not in ["WandbLogger", "MyLogger"]:
            return

        # Take the first samples of a batch from the training dataloader
        samples = next(iter(trainer.train_dataloader))['image']
        samples = samples[:self.num_samples]
        samples = samples.to(pl_module.device)

        # The batch may hold fewer images than requested; plot what is there.
        batch_size = samples.shape[0]
        if batch_size == 0:
            return

        # Generate reconstructions of the samples using the model
        with torch.no_grad():
            losses = torch.zeros(batch_size, pl_module.rotations)
            images = torch.zeros((batch_size, 3, pl_module.input_size, pl_module.input_size, pl_module.rotations))
            recons = torch.zeros((batch_size, 3, pl_module.input_size, pl_module.input_size, pl_module.rotations))
            for r in range(pl_module.rotations):
                rotate = functional.rotate(samples, 360.0 / pl_module.rotations * r, expand=False)
                crop = functional.center_crop(rotate, [pl_module.crop_size, pl_module.crop_size])
                scaled = functional.resize(crop, [pl_module.input_size, pl_module.input_size], antialias=False)

                (_, _), (_, _), _, recon = pl_module(scaled)

                losses[:, r] = pl_module.reconstruction_loss(scaled, recon)
                images[:, :, :, :, r] = scaled
                recons[:, :, :, :, r] = recon

            # Index of the best (lowest-loss) rotation for every sample
            min_idx = torch.min(losses, dim=1)[1]

        # Plot the original samples and their reconstructions side by side.
        # squeeze=False keeps `axs` two-dimensional even for a single row.
        fig, axs = plt.subplots(batch_size, 2, figsize=(6, 2 * batch_size), squeeze=False)
        try:
            for i in range(batch_size):
                best = min_idx[i]
                # NOTE(review): `.T` turns (C, H, W) into (W, H, C), i.e. the image
                # is shown transposed; kept as in the original -- confirm intent.
                axs[i, 0].imshow(images[i, :, :, :, best].cpu().detach().numpy().T)
                axs[i, 0].set_title("Original")
                axs[i, 0].axis("off")
                axs[i, 1].imshow(recons[i, :, :, :, best].cpu().detach().numpy().T)
                axs[i, 1].set_title("Reconstruction")
                axs[i, 1].axis("off")
            fig.tight_layout()

            # Log the figure at W&B
            trainer.logger.log_image(key="Reconstructions", images=[fig])
        finally:
            # Release the figure so pyplot does not accumulate one per epoch.
            plt.close(fig)
17 changes: 12 additions & 5 deletions data/shapes_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@ def __init__(self,
shuffle: bool = True,
image_size: int = 91,
batch_size: int = 32,
num_workers: int = 1):
num_workers: int = 1,
download: bool = False):
""" Initializes the data loader
Args:
data_directories (List[str]): The data directory
shuffle (bool, optional): Whether or not to shuffle when reading. Defaults to True.
image_size (int, optional): The size of the images. Defaults to 91.
batch_size (int, optional): The batch size for training. Defaults to 32.
num_workers (int, optional): How many worker to use for loading. Defaults to 1.
download (bool, optional): Whether or not to download the data. Defaults to False.
"""
super().__init__()

Expand All @@ -32,11 +35,12 @@ def __init__(self,
self.image_size = image_size
self.batch_size = batch_size
self.num_workers = num_workers
self.download = download

self.transform_train = transforms.Compose([
transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
transforms.Normalize((0,0,0), (290,290,290)),
transforms.Resize((self.image_size, self.image_size), antialias=True)
transforms.Resize((self.image_size, self.image_size), antialias="none"),
])
self.transform_predict = self.transform_train
self.transform_val = self.transform_train
Expand All @@ -57,15 +61,17 @@ def setup(self, stage: str):
"""
if stage == "fit":
self.data_train = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_train)
transform=self.transform_train,
download=self.download)

self.dataloader_train = DataLoader(self.data_train,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers)
if stage == "predict":
self.data_predict = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_predict)
transform=self.transform_predict,
download=self.download)

self.dataloader_predict = DataLoader(self.data_predict,
batch_size=self.batch_size,
Expand All @@ -74,7 +80,8 @@ def setup(self, stage: str):

if stage == "val":
self.data_val = ShapesDataset(data_directory=self.data_directory,
transform=self.transform_val)
transform=self.transform_val,
download=self.download)

self.dataloader_val = DataLoader(self.data_val,
batch_size=self.batch_size,
Expand Down
9 changes: 8 additions & 1 deletion data/shapes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@ class ShapesDataset(Dataset):
"""
def __init__(self,
data_directory: str,
transform = None):
transform = None,
download: bool = False):
""" Initializes an Illustris sdss data set.
Args:
data_directory (str): The data directory.
transform (torchvision.transforms.Compose, optional): A single or a set of
transformations to modify the images. Defaults to None.
download (bool, optional): Whether or not to download the data. Defaults to False.
"""
self.data_directory = data_directory
self.transform = transform
self.download = download

if self.download:
raise NotImplementedError("Download not implemented yet.")

self.images = np.empty((0,64,64), np.float32)
for file in os.listdir(data_directory):
self.images = np.append(self.images, np.load(os.path.join(data_directory, file)),
Expand Down
38 changes: 27 additions & 11 deletions devel/power-kl-divergence.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions experiments/shapes-power.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@ lr_scheduler:
trainer:
max_epochs: -1
accelerator: gpu
devices: [2]
devices: 1
precision: 32
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
name: shapes-power
project: profiling
# name: shapes-power
log_model: True
callbacks:
- class_path: callbacks.LogReconstructionCallback
init_args:
num_samples: 4
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: train_loss
Expand Down
13 changes: 13 additions & 0 deletions experiments/wandb-log-reconstructions.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Config overlay that logs reconstructed images to W&B during training.
# Append it after an experiment config, e.g.
#   python main.py fit -c experiments/illustris.yaml -c experiments/wandb-log-reconstructions.yaml
trainer:
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
log_model: all
entity: ain-space
tags:
- log-reconstructions
callbacks:
# Plots originals next to reconstructions; num_samples limits the images per figure.
- class_path: callbacks.LogReconstructionCallback
init_args:
num_samples: 4
5 changes: 3 additions & 2 deletions models/rotational_variational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,6 @@ def reconstruct(self, coordinates):
return self.decode(coordinates)

def reconstruction_loss(self, images, reconstructions):
return nn.MSELoss(reduction='none')(
reconstructions.reshape(-1, 3*64*64), images.reshape(-1, 3*64*64)).sum(-1).mean()
return torch.sqrt(nn.MSELoss(reduction='none')(
reconstructions.reshape(-1, self.total_input_size),
images.reshape(-1, self.total_input_size)).mean(dim=1))
3 changes: 0 additions & 3 deletions models/rotational_variational_autoencoder_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,3 @@ def reconstruction_loss(self, images, reconstructions):
return torch.sqrt(nn.MSELoss(reduction='none')(
reconstructions.reshape(-1, self.total_input_size),
images.reshape(-1, self.total_input_size)).mean(dim=1))
# return nn.MSELoss(reduction='none')(
# reconstructions.reshape(-1, self.total_input_size),
# images.reshape(-1, self.total_input_size)).sum(-1)
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
astropy==5.3.4
lightning==2.1.0
matplotlib==3.7.2
scikit-image==0.22.0
scipy==1.11.3
torch==2.1.0
Expand Down
20 changes: 20 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ charset-normalizer==3.3.0
# via
# aiohttp
# requests
contourpy==1.2.0
# via matplotlib
cycler==0.12.1
# via matplotlib
filelock==3.12.4
# via
# torch
# triton
fonttools==4.44.0
# via matplotlib
frozenlist==1.4.0
# via
# aiohttp
Expand All @@ -41,6 +47,8 @@ imageio==2.31.5
# via scikit-image
jinja2==3.1.2
# via torch
kiwisolver==1.4.5
# via matplotlib
lazy-loader==0.3
# via scikit-image
lightning==2.1.0
Expand All @@ -52,6 +60,8 @@ lightning-utilities==0.9.0
# torchmetrics
markupsafe==2.1.3
# via jinja2
matplotlib==3.7.2
# via -r requirements.in
mpmath==1.3.0
# via sympy
multidict==6.0.4
Expand All @@ -65,8 +75,10 @@ networkx==3.1
numpy==1.26.0
# via
# astropy
# contourpy
# imageio
# lightning
# matplotlib
# pyerfa
# pytorch-lightning
# scikit-image
Expand Down Expand Up @@ -110,15 +122,21 @@ packaging==23.2
# astropy
# lightning
# lightning-utilities
# matplotlib
# pytorch-lightning
# scikit-image
pillow==10.0.1
# via
# imageio
# matplotlib
# scikit-image
# torchvision
pyerfa==2.0.1
# via astropy
pyparsing==3.0.9
# via matplotlib
python-dateutil==2.8.2
# via matplotlib
pytorch-lightning==2.0.9.post0
# via lightning
pyyaml==6.0.1
Expand All @@ -136,6 +154,8 @@ scipy==1.11.3
# via
# -r requirements.in
# scikit-image
six==1.16.0
# via python-dateutil
sympy==1.12
# via torch
tifffile==2023.9.26
Expand Down
Binary file added tests/data/shapes/boxes.npy
Binary file not shown.
Binary file added tests/data/shapes/circles.npy
Binary file not shown.
Binary file added tests/data/shapes/crosses.npy
Binary file not shown.
Binary file added tests/data/shapes/triangles.npy
Binary file not shown.
73 changes: 73 additions & 0 deletions tests/test_log_reconstruction_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import matplotlib.pyplot as plt
from lightning.pytorch.loggers import Logger
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.utilities import rank_zero_only

from callbacks import LogReconstructionCallback
from data import ShapesDataModule
from models import RotationalVariationalAutoencoderPower


class MyLogger(Logger):
    """In-memory logger stub that records every ``log_image`` call.

    The class name must stay ``MyLogger``: ``LogReconstructionCallback``
    selects supported loggers by class name.
    """

    def __init__(self):
        # Initialise the base class before adding the bookkeeping attributes.
        super().__init__()
        self.calls = 0  # number of log_image invocations
        self.logged_items = []  # (key, images) tuples in call order

    @property
    def name(self):
        """Return the logger name."""
        return "MyLogger"

    @property
    def version(self):
        """Return the logger version."""
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        """Ignore hyperparameters."""

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Ignore metrics."""

    @rank_zero_only
    def save(self):
        """Nothing to persist."""

    @rank_zero_only
    def finalize(self, status):
        """Nothing to finalize."""

    def log_image(self, key, images):
        """Record the call so the test can inspect what was logged."""
        self.calls += 1
        self.logged_items.append((key, images))


def test_on_train_epoch_end():
    """The callback logs exactly one figure under the key "Reconstructions"."""

    # Set up the model and dataloader
    z_dim = 3
    model = RotationalVariationalAutoencoderPower(z_dim=z_dim)

    datamodule = ShapesDataModule("tests/data/shapes", num_workers=1, batch_size=12)
    datamodule.setup("fit")

    logger = MyLogger()

    # A short fit attaches the dataloaders the callback reads from the trainer.
    trainer = Trainer(max_epochs=1, logger=logger, overfit_batches=2)
    trainer.fit(model, datamodule=datamodule)

    # Set up the callback
    num_samples = 2
    callback = LogReconstructionCallback(num_samples=num_samples)

    # Call the callback
    callback.on_train_epoch_end(trainer=trainer, pl_module=model)

    logger.finalize("success")

    # Check that the figure was logged
    assert logger.calls == 1
    key, images = logger.logged_items[0]
    assert key == "Reconstructions"
    assert isinstance(images[0], plt.Figure)

    # Do not leave the figure open for subsequent tests.
    plt.close(images[0])
6 changes: 3 additions & 3 deletions tests/test_rotational_variational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def test_reconstruction_loss():
image3 = torch.zeros((2,3,64,64))
image3[0,0,0,0] = 1.0

assert model.reconstruction_loss(image1, image1) == 0.0
assert torch.isclose(model.reconstruction_loss(image1, image2), torch.Tensor([3*64*64]), rtol = 1e-3)
assert torch.isclose(model.reconstruction_loss(image1, image3), torch.Tensor([0.5]), rtol = 1e-3)
assert torch.isclose(model.reconstruction_loss(image1, image1), torch.Tensor([0., 0.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image2), torch.Tensor([1., 1.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image3), torch.Tensor([0.009, 0.]), atol = 1e-2).all()

0 comments on commit 7f8d8e6

Please sign in to comment.