Skip to content

Commit

Permalink
Migrate GauGAN to Keras3[TF Backend Only] (#1722)
Browse files Browse the repository at this point in the history
* Keras 3 Migration

* Keras3 Migration

* Regenerate md files
  • Loading branch information
sineeli committed Jan 11, 2024
1 parent ac7f64c commit acf47ee
Show file tree
Hide file tree
Showing 27 changed files with 332 additions and 199 deletions.
107 changes: 62 additions & 45 deletions examples/generative/gaugan.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,29 @@
We will be using the
[Facades dataset](https://cmp.felk.cvut.cz/~tylecr1/facade/)
for training our GauGAN model. Let's first download it. We also install
TensorFlow Addons.
for training our GauGAN model. Let's first download it.
"""

"""shell
gdown https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj
wget https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj -O facades_data.zip
unzip -q facades_data.zip
pip install -qqq tensorflow_addons
"""

"""
## Imports
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"


import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras import layers
import keras
from keras import ops
from keras import layers

from glob import glob

Expand Down Expand Up @@ -150,6 +152,11 @@ def _load_data_tf(image_file, segmentation_map_file, label_file):
segmentation_map = tf.cast(segmentation_map, tf.float32) / 127.5 - 1
return segmentation_map, image, labels

def _one_hot(segmentation_maps, real_images, labels):
labels = tf.one_hot(labels, NUM_CLASSES)
labels.set_shape((None, None, NUM_CLASSES))
return segmentation_maps, real_images, labels

segmentation_map_files = [
image_file.replace("images", "segmentation_map").replace("jpg", "png")
for image_file in image_files
Expand All @@ -165,10 +172,9 @@ def _load_data_tf(image_file, segmentation_map_file, label_file):
dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
dataset = dataset.map(_load_data_tf, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(_random_crop, num_parallel_calls=AUTOTUNE)
dataset = dataset.map(
lambda x, y, z: (x, y, tf.one_hot(z, NUM_CLASSES)), num_parallel_calls=AUTOTUNE
)
return dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(_one_hot, num_parallel_calls=AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset


train_dataset = load(train_files, batch_size=BATCH_SIZE, is_train=True)
Expand Down Expand Up @@ -242,12 +248,12 @@ def build(self, input_shape):
self.resize_shape = input_shape[1:3]

def call(self, input_tensor, raw_mask):
mask = tf.image.resize(raw_mask, self.resize_shape, method="nearest")
mask = ops.image.resize(raw_mask, self.resize_shape, interpolation="nearest")
x = self.conv(mask)
gamma = self.conv_gamma(x)
beta = self.conv_beta(x)
mean, var = tf.nn.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
std = tf.sqrt(var + self.epsilon)
mean, var = ops.moments(input_tensor, axes=(0, 1, 2), keepdims=True)
std = ops.sqrt(var + self.epsilon)
normalized = (input_tensor - mean) / std
output = gamma * normalized + beta
return output
Expand All @@ -273,11 +279,13 @@ def build(self, input_shape):

def call(self, input_tensor, mask):
x = self.spade_1(input_tensor, mask)
x = self.conv_1(tf.nn.leaky_relu(x, 0.2))
x = self.conv_1(keras.activations.leaky_relu(x, 0.2))
x = self.spade_2(x, mask)
x = self.conv_2(tf.nn.leaky_relu(x, 0.2))
x = self.conv_2(keras.activations.leaky_relu(x, 0.2))
skip = (
self.conv_3(tf.nn.leaky_relu(self.spade_3(input_tensor, mask), 0.2))
self.conv_3(
keras.activations.leaky_relu(self.spade_3(input_tensor, mask), 0.2)
)
if self.learned_skip
else input_tensor
)
Expand All @@ -290,13 +298,17 @@ def __init__(self, batch_size, latent_dim, **kwargs):
super().__init__(**kwargs)
self.batch_size = batch_size
self.latent_dim = latent_dim
self.seed_generator = keras.random.SeedGenerator(1337)

def call(self, inputs):
means, variance = inputs
epsilon = tf.random.normal(
shape=(self.batch_size, self.latent_dim), mean=0.0, stddev=1.0
epsilon = keras.random.normal(
shape=(self.batch_size, self.latent_dim),
mean=0.0,
stddev=1.0,
seed=self.seed_generator,
)
samples = means + tf.exp(0.5 * variance) * epsilon
samples = means + ops.exp(0.5 * variance) * epsilon
return samples


Expand Down Expand Up @@ -325,7 +337,7 @@ def downsample(
)
)
if apply_norm:
block.add(tfa.layers.InstanceNormalization())
block.add(layers.GroupNormalization(groups=-1))
if apply_activation:
block.add(layers.LeakyReLU(0.2))
if apply_dropout:
Expand Down Expand Up @@ -372,7 +384,7 @@ def build_encoder(image_shape, encoder_downsample_factor=64, latent_dim=256):


def build_generator(mask_shape, latent_dim=256):
latent = keras.Input(shape=(latent_dim))
latent = keras.Input(shape=(latent_dim,))
mask = keras.Input(shape=mask_shape)
x = layers.Dense(16384)(latent)
x = layers.Reshape((4, 4, 1024))(x)
Expand All @@ -388,8 +400,8 @@ def build_generator(mask_shape, latent_dim=256):
x = layers.UpSampling2D((2, 2))(x)
x = ResBlock(filters=128)(x, mask)
x = layers.UpSampling2D((2, 2))(x)
x = tf.nn.leaky_relu(x, 0.2)
output_image = tf.nn.tanh(layers.Conv2D(3, 4, padding="same")(x))
x = keras.activations.leaky_relu(x, 0.2)
output_image = keras.activations.tanh(layers.Conv2D(3, 4, padding="same")(x))
return keras.Model([latent, mask], output_image, name="generator")


Expand Down Expand Up @@ -436,11 +448,11 @@ def build_discriminator(image_shape, downsample_factor=64):


def generator_loss(y):
return -tf.reduce_mean(y)
return -ops.mean(y)


def kl_divergence_loss(mean, variance):
return -0.5 * tf.reduce_sum(1 + variance - tf.square(mean) - tf.exp(variance))
return -0.5 * ops.sum(1 + variance - ops.square(mean) - ops.exp(variance))


class FeatureMatchingLoss(keras.losses.Loss):
Expand Down Expand Up @@ -488,8 +500,7 @@ def __init__(self, **kwargs):
self.hinge_loss = keras.losses.Hinge()

def call(self, y, is_real):
label = 1.0 if is_real else -1.0
return self.hinge_loss(label, y)
return self.hinge_loss(is_real, y)


"""
Expand All @@ -504,10 +515,14 @@ def __init__(self, val_dataset, n_samples, epoch_interval=5):
self.val_images = next(iter(val_dataset))
self.n_samples = n_samples
self.epoch_interval = epoch_interval
self.seed_generator = keras.random.SeedGenerator(42)

def infer(self):
latent_vector = tf.random.normal(
shape=(self.model.batch_size, self.model.latent_dim), mean=0.0, stddev=2.0
latent_vector = keras.random.normal(
shape=(self.model.batch_size, self.model.latent_dim),
mean=0.0,
stddev=2.0,
seed=self.seed_generator,
)
return self.model.predict([latent_vector, self.val_images[2]])

Expand Down Expand Up @@ -569,11 +584,11 @@ def __init__(
self.sampler = GaussianSampler(batch_size, latent_dim)
self.patch_size, self.combined_model = self.build_combined_generator()

self.disc_loss_tracker = tf.keras.metrics.Mean(name="disc_loss")
self.gen_loss_tracker = tf.keras.metrics.Mean(name="gen_loss")
self.feat_loss_tracker = tf.keras.metrics.Mean(name="feat_loss")
self.vgg_loss_tracker = tf.keras.metrics.Mean(name="vgg_loss")
self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")
self.disc_loss_tracker = keras.metrics.Mean(name="disc_loss")
self.gen_loss_tracker = keras.metrics.Mean(name="gen_loss")
self.feat_loss_tracker = keras.metrics.Mean(name="feat_loss")
self.vgg_loss_tracker = keras.metrics.Mean(name="vgg_loss")
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

@property
def metrics(self):
Expand All @@ -596,13 +611,13 @@ def build_combined_generator(self):
self.discriminator.trainable = False
mask_input = keras.Input(shape=self.mask_shape, name="mask")
image_input = keras.Input(shape=self.image_shape, name="image")
latent_input = keras.Input(shape=(self.latent_dim), name="latent")
latent_input = keras.Input(shape=(self.latent_dim,), name="latent")
generated_image = self.generator([latent_input, mask_input])
discriminator_output = self.discriminator([image_input, generated_image])
combined_outputs = discriminator_output + [generated_image]
patch_size = discriminator_output[-1].shape[1]
combined_model = keras.Model(
[latent_input, mask_input, image_input],
[discriminator_output, generated_image],
[latent_input, mask_input, image_input], combined_outputs
)
return patch_size, combined_model

Expand All @@ -623,8 +638,8 @@ def train_discriminator(self, latent_vector, segmentation_map, real_image, label
with tf.GradientTape() as gradient_tape:
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, real_image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
loss_fake = self.discriminator_loss(pred_fake, -1.0)
loss_real = self.discriminator_loss(pred_real, 1.0)
total_loss = 0.5 * (loss_fake + loss_real)

self.discriminator.trainable = True
Expand All @@ -644,9 +659,10 @@ def train_generator(
self.discriminator.trainable = False
with tf.GradientTape() as tape:
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
combined_outputs = self.combined_model(
[latent_vector, labels, segmentation_map]
)
fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
pred = fake_d_output[-1]

# Compute generator losses.
Expand Down Expand Up @@ -702,13 +718,14 @@ def test_step(self, data):
# Calculate the losses.
pred_fake = self.discriminator([segmentation_map, fake_images])[-1]
pred_real = self.discriminator([segmentation_map, image])[-1]
loss_fake = self.discriminator_loss(pred_fake, False)
loss_real = self.discriminator_loss(pred_real, True)
loss_fake = self.discriminator_loss(pred_fake, -1.0)
loss_real = self.discriminator_loss(pred_real, 1.0)
total_discriminator_loss = 0.5 * (loss_fake + loss_real)
real_d_output = self.discriminator([segmentation_map, image])
fake_d_output, fake_image = self.combined_model(
combined_outputs = self.combined_model(
[latent_vector, labels, segmentation_map]
)
fake_d_output, fake_image = combined_outputs[:-1], combined_outputs[-1]
pred = fake_d_output[-1]
g_loss = generator_loss(pred)
kl_loss = self.kl_divergence_loss_coeff * kl_divergence_loss(mean, variance)
Expand Down Expand Up @@ -772,7 +789,7 @@ def plot_history(item):
for _ in range(5):
val_images = next(val_iterator)
# Sample latent from a normal distribution.
latent_vector = tf.random.normal(
latent_vector = keras.random.normal(
shape=(gaugan.batch_size, gaugan.latent_dim), mean=0.0, stddev=2.0
)
# Generate fake images.
Expand Down
Binary file modified examples/generative/img/gaugan/gaugan_11_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_11_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_11_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_11_4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_15.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_17.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_18.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_21.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_22.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_23.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_24.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_31_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_31_8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_33_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/generative/img/gaugan/gaugan_33_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/generative/img/gaugan/gaugan_33_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit acf47ee

Please sign in to comment.