From 0ed66e4b077e54e31aeda44d679a6ce387b62156 Mon Sep 17 00:00:00 2001
From: Alessio Placitelli
Date: Sat, 7 Oct 2023 08:13:55 +0200
Subject: [PATCH] Use real and synthetic data for AdobeVFR autoencoder training

The paper mentions that, for training the autoencoder, both real and
synthetic data are used and that this improves the final inference
performance. Let's do the same here.
---
 src/fontina/train.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 50 insertions(+), 10 deletions(-)

diff --git a/src/fontina/train.py b/src/fontina/train.py
index 570fbd0..7ce54bd 100644
--- a/src/fontina/train.py
+++ b/src/fontina/train.py
@@ -8,8 +8,6 @@
 )
 
 import torch
-from torchvision import datasets, transforms
-from torch.utils.data import DataLoader
 from fontina.adobevfr_dataset import AdobeVFRDataset
 from fontina.augmentation_utils import (
     get_deepfont_full_augmentations,
@@ -22,6 +20,18 @@
 from fontina.models.lightning_generate_callback import GenerateCallback
 from fontina.models.lightning_wrappers import DeepFontAutoencoderWrapper, DeepFontWrapper
+
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+# torchvision's `ImageFolder` dataset uses Pillow under the hood
+# to load image files. However, PIL will fail to load some PNG
+# files by throwing a zlib decompression error:
+# "ValueError: Decompressed Data Too Large". Setting
+# `LOAD_TRUNCATED_IMAGES = True` mitigates this problem.
+from PIL import Image, ImageFile, UnidentifiedImageError
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 def get_parser():
     parser = argparse.ArgumentParser(description="Fontina training")
@@ -90,31 +100,61 @@ def load_and_split_data(train_config):
 
 def load_adobevfr_dataset(train_config):
     all_train_data = AdobeVFRDataset(
-        f"{train_config['data_root']}/VFR_syn_train",
+        f"{train_config['data_root']}/BCF format/VFR_syn_train",
         "train",
         get_deepfont_full_augmentations(),
     )
+    num_labels = all_train_data.num_labels
+
+    # From the DeepFont paper: "We first train the SCAE on both synthetic and
+    # real-world data in an unsupervised way [...]". When training the autoencoder,
+    # we merge the real and synthetic datasets for training purposes.
+    if train_config["only_autoencoder"]:
+
+        def is_valid_image(path):
+            try:
+                _ = Image.open(path)
+                return True
+            except (UnidentifiedImageError, ValueError):
+                print(f"Failed to load image: {path}")
+                return False
+
+        real_data = datasets.ImageFolder(
+            root=f"{train_config['data_root']}/Raw Image/VFR_real_u",
+            # Important: albumentations can't convert to grayscale and output
+            # only one channel, so do it here.
+            transform=transforms.Grayscale(num_output_channels=1),
+            target_transform=None,
+            is_valid_file=is_valid_image,
+        )
+        real_data_processed = AugmentedDataset(
+            real_data, 1, get_deepfont_full_augmentations()
+        )
+
+        # Override 'all_train_data' with the joined dataset.
+        all_train_data = torch.utils.data.ConcatDataset(
+            [all_train_data, real_data_processed]
+        )
+
     # Although the AdobeVFR dataset readme says that VFR_syn_val contains
     # the validation for the same classes as VFR_syn_train, that doesn't
     # seem to be the case: the former contains 2383 classes, the latter
     # 4383. Instead of using it, let's split the train set.
     splits = torch.utils.data.random_split(all_train_data, [0.95, 0.05])
-    train_set_processed = AugmentedDataset(splits[0], all_train_data.num_labels, None)
-    validation_set_processed = AugmentedDataset(
-        splits[1], all_train_data.num_labels, None
-    )
+    train_set_processed = AugmentedDataset(splits[0], num_labels, None)
+    validation_set_processed = AugmentedDataset(splits[1], num_labels, None)
 
     """
     validation_set_processed = AdobeVFRDataset(
-        f"{train_config['data_root']}/VFR_syn_val",
+        f"{train_config['data_root']}/BCF format/VFR_syn_val",
         "val",
         get_random_square_patch_augmentation(),
     )
     """
 
     test_set_processed = (
         AdobeVFRDataset(
-            f"{train_config['data_root']}/VFR_real_test",
+            f"{train_config['data_root']}/BCF format/VFR_real_test",
             "vfr_large",
             get_random_square_patch_augmentation(),
         )
@@ -123,7 +163,7 @@ def load_adobevfr_dataset(train_config):
     )
 
     return (
-        all_train_data.num_labels,
+        num_labels,
         train_set_processed,
         validation_set_processed,
         test_set_processed,
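
For reference, the merge-then-split flow added above boils down to the standalone
sketch below. It is an illustration only, not part of the patch: torchvision's
FakeData stands in for the AdobeVFR folders and the repository's
AdobeVFRDataset/AugmentedDataset wrappers, the albumentations pipeline is replaced
by a plain ToTensor, and the dataset sizes, image size and batch size are made up.

# Illustrative sketch: merge a synthetic and a real dataset, split off a
# validation set, and wrap everything in DataLoaders, mirroring the patch.
import torch
from torch.utils.data import ConcatDataset, DataLoader, random_split
from torchvision import datasets, transforms

# Convert to single-channel grayscale up front, as the patch does for the
# real-world images, then turn the PIL images into tensors.
to_gray = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
)

# Stand-ins for the synthetic (VFR_syn_train) and real (VFR_real_u) sets.
synthetic_data = datasets.FakeData(
    size=200, image_size=(3, 105, 105), transform=to_gray
)
real_data = datasets.FakeData(size=50, image_size=(3, 105, 105), transform=to_gray)

# Merge both sources for unsupervised autoencoder training, then carve out
# a 5% validation split, as in the patch.
merged = ConcatDataset([synthetic_data, real_data])
train_set, val_set = random_split(merged, [0.95, 0.05])

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(val_set, batch_size=16)

images, _ = next(iter(train_loader))
print(images.shape)  # torch.Size([16, 1, 105, 105])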