Skip to content

Commit

Permalink
Use real and synt data for AdobeVFR autoencoder train
Browse files Browse the repository at this point in the history
The paper mentions that, for training the autoencoder,
both real and synthetic data are used and that this
improves the performance of the final inference. Let's
do the same here.
  • Loading branch information
Dexterp37 committed Oct 7, 2023
1 parent 6207af3 commit 0ed66e4
Showing 1 changed file with 50 additions and 10 deletions.
60 changes: 50 additions & 10 deletions src/fontina/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
)
import torch

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from fontina.adobevfr_dataset import AdobeVFRDataset
from fontina.augmentation_utils import (
get_deepfont_full_augmentations,
Expand All @@ -22,6 +20,18 @@
from fontina.models.lightning_generate_callback import GenerateCallback
from fontina.models.lightning_wrappers import DeepFontAutoencoderWrapper, DeepFontWrapper

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# The torchvision `DataLoader` uses Pillow under the hood to
# load image files. However PIL will fail to load some PNG
# files by throwing a zlib decompression error:
# "ValueError: Decompressed Data Too Large". Setting
# `LOAD_TRUNCATED_IMAGES = True` mitigates this problem.
from PIL import Image, ImageFile, UnidentifiedImageError

ImageFile.LOAD_TRUNCATED_IMAGES = True


def get_parser():
parser = argparse.ArgumentParser(description="Fontina training")
Expand Down Expand Up @@ -90,31 +100,61 @@ def load_and_split_data(train_config):

def load_adobevfr_dataset(train_config):
all_train_data = AdobeVFRDataset(
f"{train_config['data_root']}/VFR_syn_train",
f"{train_config['data_root']}/BCF format/VFR_syn_train",
"train",
get_deepfont_full_augmentations(),
)
num_labels = all_train_data.num_labels

# From the DeepFont paper: "We first train the SCAE on both synthetic and
# real-world data in a unsupervised way [...]". When training the autoencoder,
# we merge the real and synthetic datasets for training purposes.
if train_config["only_autoencoder"]:

def is_valid_image(path):
try:
_ = Image.open(path)
return True
except (UnidentifiedImageError, ValueError):
print(f"Failed to load image: {path}")
return False

real_data = datasets.ImageFolder(
root=f"{train_config['data_root']}/Raw Image/VFR_real_u",
# Important: albumentation can't set grayscale and output only one
# channel, so do it here.
transform=transforms.Grayscale(num_output_channels=1),
target_transform=None,
is_valid_file=is_valid_image,
)
real_data_processed = AugmentedDataset(
real_data, 1, get_deepfont_full_augmentations()
)

# Override 'all_train_data' with the joined dataset.
all_train_data = torch.utils.data.ConcatDataset(
[all_train_data, real_data_processed]
)

# Although the AdobeVFR dataset readme says that VFR_syn_val contains
# the validation for the same classes as VFR_syn_train, that doens't
# seem to be the case: the former contains 2383 classes, the latter
# 4383. Instead of using it, let's split the rain set.
splits = torch.utils.data.random_split(all_train_data, [0.95, 0.05])

train_set_processed = AugmentedDataset(splits[0], all_train_data.num_labels, None)
validation_set_processed = AugmentedDataset(
splits[1], all_train_data.num_labels, None
)
train_set_processed = AugmentedDataset(splits[0], num_labels, None)
validation_set_processed = AugmentedDataset(splits[1], num_labels, None)

"""
validation_set_processed = AdobeVFRDataset(
f"{train_config['data_root']}/VFR_syn_val",
f"{train_config['data_root']}/BCF format/VFR_syn_val",
"val",
get_random_square_patch_augmentation(),
)
"""
test_set_processed = (
AdobeVFRDataset(
f"{train_config['data_root']}/VFR_real_test",
f"{train_config['data_root']}/BCF format/VFR_real_test",
"vfr_large",
get_random_square_patch_augmentation(),
)
Expand All @@ -123,7 +163,7 @@ def load_adobevfr_dataset(train_config):
)

return (
all_train_data.num_labels,
num_labels,
train_set_processed,
validation_set_processed,
test_set_processed,
Expand Down

0 comments on commit 0ed66e4

Please sign in to comment.