
Commit

minor updates
tqch committed Feb 9, 2023
1 parent 86c0f3e commit 91ef3f6
Showing 11 changed files with 227 additions and 189 deletions.
2 changes: 1 addition & 1 deletion ddim.py
@@ -93,7 +93,7 @@ def p_sample(self, denoise_fn, shape, device=torch.device("cpu"), noise=None):
B, *_ = shape
subsequence = self.subsequence.to(device)
_denoise_fn = lambda x, t: denoise_fn(x, subsequence.gather(0, t))
t = torch.empty((B,), dtype=torch.int64, device=device)
t = torch.empty((B, ), dtype=torch.int64, device=device)
if noise is None:
x_t = torch.randn(shape, device=device)
else:
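For context on the p_sample hunk above: the wrapped _denoise_fn translates DDIM subsequence indices back to the original diffusion timesteps via gather. A minimal, self-contained sketch; the 1000-step length and 50-step schedule below are assumptions for illustration, not values taken from the repo:

import torch

T = 1000                                             # assumed full diffusion length
subsequence = torch.linspace(0, T - 1, 50).long()    # assumed 50-step DDIM schedule

def denoise_fn(x, t):
    return t                                         # stand-in for the trained model: just echoes the timestep

_denoise_fn = lambda x, t: denoise_fn(x, subsequence.gather(0, t))

B = 4
t = torch.full((B,), 10, dtype=torch.int64)          # index 10 within the subsequence
print(_denoise_fn(None, t))                          # original timesteps, tensor([203, 203, 203, 203])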
9 changes: 5 additions & 4 deletions ddpm_torch/__init__.py
@@ -1,5 +1,5 @@
from .datasets import get_dataloader, DATA_INFO
from .utils import seed_all, get_param, Configs
from .datasets import get_dataloader, DATASET_DICT, DATASET_INFO
from .utils import seed_all, get_param, ConfigDict
from .utils.train import Trainer, DummyScheduler, ModelWrapper
from .metrics import Evaluator
from .diffusion import GaussianDiffusion, get_beta_schedule
@@ -8,10 +8,11 @@

__all__ = [
"get_dataloader",
"DATA_INFO",
"DATASET_DICT",
"DATASET_INFO",
"seed_all",
"get_param",
"Configs",
"ConfigDict",
"Trainer",
"DummyScheduler",
"ModelWrapper",
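The renames above (DATA_INFO to DATASET_DICT/DATASET_INFO, Configs to ConfigDict) change the package's public surface. Hypothetical downstream usage under the new names:

from ddpm_torch import DATASET_DICT, DATASET_INFO

print(sorted(DATASET_DICT))                    # registered dataset names, e.g. ["celeba", "celebahq", "cifar10", "mnist"]
print(DATASET_INFO["cifar10"]["resolution"])   # (32, 32)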
205 changes: 113 additions & 92 deletions ddpm_torch/datasets.py
@@ -1,34 +1,94 @@
import re
import os
import csv
import PIL
import torch
import numpy as np
from torchvision import transforms, datasets
from torchvision import transforms, datasets as tvds
from torch.utils.data import DataLoader, Subset, Sampler
from torch.utils.data.distributed import DistributedSampler
from collections import namedtuple

CSV = namedtuple("CSV", ["header", "index", "data"])
CONDITIONAL = False
DATASET_DICT = dict()
DATASET_INFO = dict()


def register_dataset(cls):
name = cls.__name__.lower()
DATASET_DICT[name] = cls
info = dict()
for k, v in cls.__dict__.items():
if re.match(r"__\w+__", k) is None and not callable(v):
info[k] = v
DATASET_INFO[name] = info
return cls


@register_dataset
class MNIST(tvds.MNIST):
resolution = (32, 32)
channels = 1
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_size = 60000
test_size = 10000


@register_dataset
class CIFAR10(tvds.CIFAR10):
resolution = (32, 32)
channels = 3
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
_transform = transforms.PILToTensor()
train_size = 50000
test_size = 10000


def crop_celeba(img):
return transforms.functional.crop(img, top=40, left=15, height=148, width=148) # noqa


class CelebA(datasets.VisionDataset):
@register_dataset
class CelebA(tvds.VisionDataset):
"""
Large-scale CelebFaces Attributes (CelebA) Dataset <https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>
"""
base_folder = "celeba"
resolution = (64, 64)
channels = 3
transform = transforms.Compose([
crop_celeba,
transforms.Resize((64, 64)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
_transform = transforms.Compose([
crop_celeba,
transforms.Resize((64, 64)),
transforms.PILToTensor()
])
all_size = 202599
train_size = 162770
test_size = 19962
valid_size = 19867

def __init__(
self,
root,
split,
transform=transforms.ToTensor()
transform=None
):
super().__init__(root, transform=transform)
super().__init__(root, transform=transform or self._transform)
self.split = split
split_map = {
"train": 0,
@@ -65,13 +125,13 @@ def _load_csv(
return CSV(headers, indices, torch.as_tensor(data_int))

def __getitem__(self, index):
X = PIL.Image.open(os.path.join( # noqa
im = PIL.Image.open(os.path.join( # noqa
self.root, self.base_folder, "img_align_celeba", self.filename[index]))

if self.transform is not None:
X = self.transform(X)
im = self.transform(im)

return X
return im

def __len__(self):
return len(self.filename)
@@ -81,20 +141,30 @@ def extra_repr(self):
return "\n".join(lines).format(**self.__dict__)


class CelebAHQ(datasets.VisionDataset):
@register_dataset
class CelebAHQ(tvds.VisionDataset):
"""
High-Quality version of the CELEBA dataset, consisting of 30000 images in 1024 x 1024 resolution
created by Karras et al. (2018) [1]
[1] Karras, Tero, et al. "Progressive Growing of GANs for Improved Quality, Stability, and Variation." International Conference on Learning Representations. 2018.
""" # noqa
base_folder = "celeba_hq"
resolution = (256, 256)
channels = 3
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
_transform = transforms.PILToTensor()
all_size = 30000

def __init__(
self,
root,
transform=transforms.ToTensor()
transform=None
):
super().__init__(root, transform=transform)
super().__init__(root, transform=transform or self._transform)
self.filename = sorted([
fname
for fname in os.listdir(os.path.join(root, self.base_folder, "img_celeba_hq"))
@@ -103,85 +173,23 @@ def __init__(
np.random.RandomState(123).shuffle(self.filename) # legacy order used by ProGAN

def __getitem__(self, index):
X = PIL.Image.open(os.path.join( # noqa
im = PIL.Image.open(os.path.join( # noqa
self.root, self.base_folder, "img_celeba_hq", self.filename[index]))

if self.transform is not None:
X = self.transform(X)
im = self.transform(im)

return X
return im

def __len__(self):
return len(self.filename)


DATA_INFO = {
"mnist": {
"data": datasets.MNIST,
"resolution": (32, 32),
"channels": 1,
"transform": transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, ))
]),
"train_size": 60000,
"test_size": 10000
},
"cifar10": {
"data": datasets.CIFAR10,
"resolution": (32, 32),
"channels": 3,
"transform": transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
]),
"_transform": transforms.PILToTensor(),
"train_size": 50000,
"test_size": 10000
},
"celeba": {
"data": datasets.CelebA if CONDITIONAL else CelebA,
"resolution": (64, 64),
"channels": 3,
"transform": transforms.Compose([
crop_celeba,
transforms.Resize((64, 64)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]),
"_transform": transforms.Compose([
crop_celeba,
transforms.Resize((64, 64)),
transforms.PILToTensor()
]),
"all": 202599,
"train": 162770,
"test": 19962,
"validation": 19867
},
"celebahq": {
"data": CelebAHQ,
"resolution": (256, 256),
"channels": 3,
"transform": transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]),
"_transform": transforms.PILToTensor(),
"all": 30000
}
}

ROOT = os.path.expanduser("~/datasets")


def train_val_split(dataset, val_size, random_seed=None):
train_size = DATA_INFO[dataset]["train_size"]
train_size = DATASET_INFO[dataset]["train_size"]
if random_seed is not None:
np.random.seed(random_seed)
train_inds = np.arange(train_size)
@@ -215,30 +223,43 @@ def get_dataloader(
distributed=False
):
assert isinstance(val_size, float) and 0 <= val_size < 1
transform = DATA_INFO[dataset]["transform"]

name, dataset = dataset, DATASET_DICT[dataset]
transform = dataset.transform
if distributed:
batch_size = batch_size // int(os.environ.get("WORLD_SIZE", "1"))
dataloader_configs = {
"batch_size": batch_size,
"pin_memory": pin_memory,
"drop_last": drop_last,
"num_workers": num_workers
}

data_kwargs = {"root": root, "transform": transform}
if dataset == "celeba":
if name == "celeba":
data_kwargs["split"] = split
elif dataset in {"mnist", "cifar10"}:
elif name in {"mnist", "cifar10"}:
data_kwargs["download"] = False
data_kwargs["train"] = split != "test"
data = DATA_INFO[dataset]["data"](**data_kwargs)
dataset = dataset(**data_kwargs)

if data_kwargs.get("train", False) and val_size > 0.:
train_inds, val_inds = train_val_split(dataset, val_size, random_seed)
data = Subset(data, {"train": train_inds, "valid": val_inds}[split])
train_inds, val_inds = train_val_split(name, val_size, random_seed)
dataset = Subset(dataset, {"train": train_inds, "valid": val_inds}[split])

dataloader_configs = {
"batch_size": batch_size,
"pin_memory": pin_memory,
"drop_last": drop_last,
"num_workers": num_workers
}
dataloader_configs["sampler"] = sampler = DistributedSampler(
data, shuffle=True, seed=random_seed, drop_last=drop_last) if distributed else None
dataset, shuffle=True, seed=random_seed, drop_last=drop_last) if distributed else None
dataloader_configs["shuffle"] = (sampler is None) if split in {"train", "all"} else False
dataloader = DataLoader(data, **dataloader_configs)
dataloader = DataLoader(dataset, **dataloader_configs)
return dataloader, sampler


if __name__ == "__main__":
try:
from .utils import dict2str
except ImportError:
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[0]))
from utils import dict2str

print(dict2str(DATASET_INFO, compact=False))
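The hand-maintained DATA_INFO dictionary removed above is replaced by the register_dataset decorator, which builds DATASET_DICT and DATASET_INFO from class attributes. A stripped-down sketch of the same pattern, using a hypothetical dataset class rather than one from the repo:

import re

DATASET_DICT, DATASET_INFO = {}, {}

def register_dataset(cls):
    name = cls.__name__.lower()
    DATASET_DICT[name] = cls
    # collect non-dunder, non-callable class attributes as dataset metadata
    DATASET_INFO[name] = {
        k: v for k, v in cls.__dict__.items()
        if re.match(r"__\w+__", k) is None and not callable(v)
    }
    return cls

@register_dataset
class ToyData:                      # hypothetical dataset, not part of the repo
    resolution = (32, 32)
    channels = 3
    train_size = 50000

print(DATASET_DICT["toydata"])      # <class '__main__.ToyData'>
print(DATASET_INFO["toydata"])      # {'resolution': (32, 32), 'channels': 3, 'train_size': 50000}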
2 changes: 1 addition & 1 deletion ddpm_torch/diffusion.py
@@ -50,7 +50,7 @@ def __init__(

alphas = 1 - betas
self.alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), self.alphas_bar[:-1]])
alphas_bar_prev = torch.cat([torch.as_tensor([1., ], dtype=torch.float64), self.alphas_bar[:-1]])

# q(x_t | x_0)
self.sqrt_alphas_bar = torch.sqrt(self.alphas_bar)
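The alphas_bar_prev hunk above only changes how the leading 1 is constructed; both versions compute the cumulative product of alphas shifted right by one position, with the alpha-bar value before step 0 defined as 1. A quick check with a toy schedule (the linear betas here are illustrative, not the repo's defaults):

import torch

betas = torch.linspace(1e-4, 0.02, 10, dtype=torch.float64)   # toy schedule for illustration
alphas = 1 - betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = torch.cat([torch.as_tensor([1.], dtype=torch.float64), alphas_bar[:-1]])

assert alphas_bar_prev[0] == 1.0
assert torch.allclose(alphas_bar_prev[1:], alphas_bar[:-1])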
4 changes: 2 additions & 2 deletions ddpm_torch/metrics/fid_score.py
@@ -1,7 +1,7 @@
"""
This script is modified from the original PyTorch Implementation of FID
(https://github.com/mseitzer/pytorch-fid/) to support fid evaulation on
the fly during the training process without writing data onto the disk.
(https://github.com/mseitzer/pytorch-fid/) to support fid evaluation on
the fly without writing data onto the disk during the training process.
"""

"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
2 changes: 1 addition & 1 deletion ddpm_torch/models/unet.py
@@ -69,7 +69,7 @@ def __init__(
in_channels,
out_channels,
embed_dim,
drop_rate=0.5
drop_rate=0.
):
super(ResidualBlock, self).__init__()
self.norm1 = self.normalize(in_channels)
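The hunk above changes the residual block's default drop_rate from 0.5 to 0. A hypothetical minimal residual block (not the repo's implementation, which also takes embed_dim) showing where such a default typically enters:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, drop_rate=0.):
        super().__init__()
        self.norm = nn.GroupNorm(1, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.dropout = nn.Dropout(drop_rate)             # no-op when drop_rate=0.
        self.skip = (nn.Conv2d(in_channels, out_channels, 1)
                     if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        h = self.dropout(self.conv(F.silu(self.norm(x))))
        return self.skip(x) + h

x = torch.randn(2, 16, 8, 8)
print(TinyResBlock(16, 32)(x).shape)                     # torch.Size([2, 32, 8, 8])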
