Skip to content

Commit

Permalink
Merge pull request #37 from BerndDoser/kl
Browse files Browse the repository at this point in the history
Fix KL divergence of power spherical distribution
  • Loading branch information
BerndDoser authored Oct 28, 2023
2 parents 658b8f8 + e52aed9 commit a5bfda5
Show file tree
Hide file tree
Showing 14 changed files with 985 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__pycache__
*.ckpt
*.ncu-rep
config.yaml
HiPSter/
lightning_logs/
local/
wandb/
175 changes: 175 additions & 0 deletions devel/check-illustris-power-reconstruction.ipynb

Large diffs are not rendered by default.

175 changes: 175 additions & 0 deletions devel/check-shapes-power-reconstruction.ipynb

Large diffs are not rendered by default.

247 changes: 247 additions & 0 deletions devel/power-kl-divergence.ipynb

Large diffs are not rendered by default.

168 changes: 168 additions & 0 deletions devel/svae-kl-divergence.ipynb

Large diffs are not rendered by default.

62 changes: 62 additions & 0 deletions experiments/illustris-power.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
seed_everything: 42

model:
class_path: models.RotationalVariationalAutoencoderPower
init_args:
h_dim: 256
z_dim: 3
image_size: 363
rotations: 36
beta: 1.0e-3

data:
class_path: data.IllustrisSdssDataModule
init_args:
data_directories: ["/local_data/AIN/SKIRT_synthetic_images/TNG100/sdss/snapnum_099/data/",
"/local_data/AIN/SKIRT_synthetic_images/TNG100/sdss/snapnum_095/data/",
"/local_data/AIN/SKIRT_synthetic_images/TNG50/sdss/snapnum_099/data/",
"/local_data/AIN/SKIRT_synthetic_images/TNG50/sdss/snapnum_095/data/",
"/local_data/AIN/SKIRT_synthetic_images/Illustris/sdss/snapnum_135/data/",
"/local_data/AIN/SKIRT_synthetic_images/Illustris/sdss/snapnum_131/data/"]
extension: fits
minsize: 100
batch_size: 128
shuffle: True
num_workers: 32

optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.001

lr_scheduler:
class_path: lightning.pytorch.cli.ReduceLROnPlateau
init_args:
mode: min
factor: 0.1
patience: 5
cooldown: 5
min_lr: 1.e-5
monitor: train_loss
verbose: True

trainer:
max_epochs: -1
accelerator: gpu
devices: [1]
precision: 32
# overfit_batches: 1
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: spherinator
name: illustris-power
log_model: True
# callbacks:
# - class_path: lightning.pytorch.callbacks.ModelCheckpoint
# init_args:
# monitor: train_loss
# filename: "{epoch}-{train_loss:.2f}"
# save_top_k: 1
# mode: min
# every_n_epochs: 1
4 changes: 2 additions & 2 deletions experiments/shapes-power.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ model:
z_dim: 3
image_size: 91
rotations: 36
beta: 2.0
beta: 0.001

data:
class_path: data.ShapesDataModule
Expand Down Expand Up @@ -37,7 +37,7 @@ lr_scheduler:
trainer:
max_epochs: -1
accelerator: gpu
devices: 1
devices: [2]
precision: 32
logger:
class_path: lightning.pytorch.loggers.WandbLogger
Expand Down
2 changes: 2 additions & 0 deletions models/rotational_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def __init__(self,
super(RotationalAutoencoder, self).__init__()
self.bottleneck = bottleneck
self.rotations = rotations
self.input_size = 128
self.example_input_array = torch.randn(1, bottleneck, self.input_size, self.input_size)
self.conv0 = nn.Conv2d(in_channels=3, out_channels=16,
kernel_size=(3,3), stride=1, padding=1) #128x128
self.pool0 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 64x64
Expand Down
78 changes: 55 additions & 23 deletions models/rotational_variational_autoencoder_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,61 +43,85 @@ def __init__(self,
self.beta = beta

self.crop_size = int(self.image_size * math.sqrt(2) / 2)
self.input_size = 64

if self.input_size > self.crop_size:
raise ValueError("Image size to small.")
self.input_size = 128
self.total_input_size = self.input_size * self.input_size * 3

self.example_input_array = torch.randn(1, 3, self.input_size, self.input_size)

self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(5,5), stride=2, padding=2)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(5,5), stride=2, padding=2)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(5,5), stride=2, padding=2)
self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(5,5), stride=2, padding=2)
self.conv0 = nn.Conv2d(in_channels=3, out_channels=16,
kernel_size=(3,3), stride=1, padding=1) #128x128
self.pool0 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 64x64
self.conv1 = nn.Conv2d(in_channels=16, out_channels=32,
kernel_size=(3,3), stride=1, padding=1) #64x64
self.pool1 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 32x32
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
kernel_size=(3,3), stride=1, padding=1) #32x32
self.pool2 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 16x16
self.conv3 = nn.Conv2d(in_channels=64, out_channels=128,
kernel_size=(3,3), stride=1, padding=1) #16x16
self.pool3 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 8x8
self.conv4 = nn.Conv2d(in_channels=128, out_channels=256,
kernel_size=(3,3), stride=1, padding=1) #8x8
self.pool4 = nn.MaxPool2d(kernel_size=(2,2), stride=2, padding=0) # 4x4

self.fc1 = nn.Linear(256*4*4, h_dim)
self.fc_mean = nn.Linear(h_dim, z_dim)
self.fc_var = nn.Linear(h_dim, 1)
self.fc2 = nn.Linear(z_dim, h_dim)
self.fc3 = nn.Linear(h_dim, 256*4*4)

self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=(4,4), stride=2, padding=1)
self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(4,4), stride=2, padding=1)
self.deconv3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(4,4), stride=2, padding=1)
self.deconv4 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(4,4), stride=2, padding=1)
self.deconv5 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(5,5), stride=1, padding=2)
self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128,
kernel_size=(4,4), stride=2, padding=1) #8x8
self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
kernel_size=(4,4), stride=2, padding=1) #16x16
self.deconv3 = nn.ConvTranspose2d(in_channels=128, out_channels=64,
kernel_size=(4,4), stride=2, padding=1) #32x32
self.deconv4 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
kernel_size=(4,4), stride=2, padding=1) #64x64
self.deconv5 = nn.ConvTranspose2d(in_channels=32, out_channels=16,
kernel_size=(3,3), stride=2, padding=1) #127x127
self.deconv6 = nn.ConvTranspose2d(in_channels=16, out_channels=3,
kernel_size=(2,2), stride=1, padding=0) #128x128

def encode(self, x):
x = F.relu(self.conv0(x))
x = self.pool0(x)
x = F.relu(self.conv1(x))
x = self.pool1(x)
x = F.relu(self.conv2(x))
x = self.pool2(x)
x = F.relu(self.conv3(x))
x = self.pool3(x)
x = F.relu(self.conv4(x))

x = self.pool4(x)
x = x.view(-1, 256*4*4)
x = F.tanh(self.fc1(x))
x = F.sigmoid(self.fc1(x))

z_mean = self.fc_mean(x)
z_mean = torch.nn.functional.normalize(z_mean, p=2.0, dim=1)
# SVAE code: the `+ 1` prevent collapsing behaviors
z_var = F.softplus(self.fc_var(x)) + 1
# z_var = F.softplus(self.fc_var(x)) + 0.1
z_var = torch.exp(self.fc_var(x)) + 20.0

return z_mean, z_var

def decode(self, z):
x = F.tanh(self.fc2(z))
x = F.tanh(self.fc3(x))
x = x.view(-1, 256, 4, 4)

x = F.relu(self.deconv1(x))
x = F.relu(self.deconv2(x))
x = F.relu(self.deconv3(x))
x = F.relu(self.deconv4(x))
x = self.deconv5(x)
x = torch.sigmoid(x)
x = F.relu(self.deconv5(x))
x = self.deconv6(x)
# x = torch.sigmoid(x)
return x

def reparameterize(self, z_mean, z_var):
q_z = PowerSpherical(z_mean, z_var)
p_z = HypersphericalUniform(self.z_dim - 1, device=z_mean.device)
p_z = HypersphericalUniform(self.z_dim, device=z_mean.device)
return q_z, p_z

def forward(self, x):
Expand All @@ -112,15 +136,17 @@ def training_step(self, batch, batch_idx):
losses = torch.zeros(images.shape[0], self.rotations)
losses_recon = torch.zeros(images.shape[0], self.rotations)
losses_KL = torch.zeros(images.shape[0], self.rotations)
z_mean = torch.zeros(self.z_dim)
z_scale = torch.zeros(self.z_dim)
for i in range(self.rotations):
rotate = functional.rotate(images, 360.0 / self.rotations * i, expand=False)
crop = functional.center_crop(rotate, [self.crop_size, self.crop_size])
scaled = functional.resize(crop, [self.input_size, self.input_size], antialias=False)

(_, _), (q_z, p_z), _, recon = self.forward(scaled)
(z_mean, z_scale), (q_z, p_z), _, recon = self.forward(scaled)

loss_recon = self.reconstruction_loss(scaled, recon)
loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z).mean()
loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z)

losses[:,i] = loss_recon + self.beta * loss_KL
losses_recon[:,i] = loss_recon
Expand All @@ -134,6 +160,8 @@ def training_step(self, batch, batch_idx):
self.log('loss_recon', loss_recon, prog_bar=True)
self.log('loss_KL', loss_KL)
self.log('learning_rate', self.optimizers().param_groups[0]['lr'])
self.log('mean(z_mean) ', torch.mean(z_mean))
self.log('mean(z_scale) ', torch.mean(z_scale))
return loss

def configure_optimizers(self):
Expand All @@ -148,5 +176,9 @@ def reconstruct(self, coordinates):
return self.decode(coordinates)

def reconstruction_loss(self, images, reconstructions):
return nn.MSELoss(reduction='none')(
reconstructions.reshape(-1, 3*64*64), images.reshape(-1, 3*64*64)).sum(-1).mean()
return torch.sqrt(nn.MSELoss(reduction='none')(
reconstructions.reshape(-1, self.total_input_size),
images.reshape(-1, self.total_input_size)).mean(-1))
# return nn.MSELoss(reduction='none')(
# reconstructions.reshape(-1, self.total_input_size),
# images.reshape(-1, self.total_input_size)).sum(-1)
17 changes: 16 additions & 1 deletion tests/test_power_spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(script_dir, '../external/power_spherical/'))
from power_spherical import PowerSpherical
from power_spherical import PowerSpherical, HypersphericalUniform


def test_power_spherical_2d():
Expand All @@ -24,3 +24,18 @@ def test_power_spherical_2d_batch():

sample = dist.rsample()
assert sample.shape == torch.Size([batch_size, 3])

def test_kl_divergence():
dim = 8
loc = torch.tensor([0.] * (dim - 1) + [1.])
scale = torch.tensor(10.)

dist1 = PowerSpherical(loc, scale)
dist2 = HypersphericalUniform(dim)
x = dist1.sample((100000,))

assert torch.isclose(
(dist1.log_prob(x) - dist2.log_prob(x)).mean(),
torch.distributions.kl_divergence(dist1, dist2),
atol=1e-2,
).all()
25 changes: 25 additions & 0 deletions tests/test_rotational_autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from models import RotationalAutoencoder
import torch

def test_forward():

model = RotationalAutoencoder()
input = model.example_input_array

recon, coord = model(input)

assert coord.shape == (1,3)
assert recon.shape == input.shape

def test_reconstruction_loss():

torch.manual_seed(0)
model = RotationalAutoencoder()
image1 = torch.zeros((2,3,64,64))
image2 = torch.ones((2,3,64,64))
image3 = torch.zeros((2,3,64,64))
image3[0,0,0,0] = 1.0

assert torch.isclose(model.reconstruction_loss(image1, image1), torch.Tensor([0., 0.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image2), torch.Tensor([1., 1.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image3), torch.Tensor([0.009, 0.]), atol = 1e-3).all()
17 changes: 16 additions & 1 deletion tests/test_rotational_variational_autoencoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from models import RotationalVariationalAutoencoder
import torch

def test_rotational_variational_autoencoder():
def test_forward():

z_dim = 2
model = RotationalVariationalAutoencoder(z_dim=z_dim)
Expand All @@ -12,3 +13,17 @@ def test_rotational_variational_autoencoder():
assert z_mean.shape == (batch_size, z_dim)
assert z_var.shape == (batch_size, z_dim)
assert recon.shape == input.shape

def test_reconstruction_loss():

torch.manual_seed(0)
z_dim = 2
model = RotationalVariationalAutoencoder(z_dim=z_dim)
image1 = torch.zeros((2,3,64,64))
image2 = torch.ones((2,3,64,64))
image3 = torch.zeros((2,3,64,64))
image3[0,0,0,0] = 1.0

assert model.reconstruction_loss(image1, image1) == 0.0
assert torch.isclose(model.reconstruction_loss(image1, image2), torch.Tensor([3*64*64]), rtol = 1e-3)
assert torch.isclose(model.reconstruction_loss(image1, image3), torch.Tensor([0.5]), rtol = 1e-3)
29 changes: 29 additions & 0 deletions tests/test_rotational_variational_autoencoder_power.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from models import RotationalVariationalAutoencoderPower
import torch

def test_forward():

z_dim = 2
model = RotationalVariationalAutoencoderPower(z_dim=z_dim)
input = model.example_input_array
batch_size = input.shape[0]

(z_mean, z_var), (_, _), _, recon = model(input)

assert z_mean.shape == (batch_size, z_dim)
assert z_var.shape == (batch_size, 1)
assert recon.shape == input.shape

def test_reconstruction_loss():

torch.manual_seed(0)
z_dim = 2
model = RotationalVariationalAutoencoderPower(z_dim=z_dim)
image1 = torch.zeros((2,3,128,128))
image2 = torch.ones((2,3,128,128))
image3 = torch.zeros((2,3,128,128))
image3[0,0,0,0] = 1.0

assert torch.isclose(model.reconstruction_loss(image1, image1), torch.Tensor([0., 0.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image2), torch.Tensor([1., 1.]), atol = 1e-3).all()
assert torch.isclose(model.reconstruction_loss(image1, image3), torch.Tensor([0.009, 0.]), atol = 1e-2).all()
12 changes: 12 additions & 0 deletions tests/test_torch_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
import torch.nn as nn

def test_MSELoss():

image1 = torch.Tensor([0.0])
image2 = torch.Tensor([0.1])

loss = nn.MSELoss(reduction='none')

assert loss(image1, image1).mean() == 0.0
assert torch.isclose(loss(image1, image2).mean(), torch.Tensor([0.01]), rtol = 1e-3)

0 comments on commit a5bfda5

Please sign in to comment.