fix --eval flag in train.py and related scripts
tqch committed Mar 1, 2023
1 parent 91ef3f6 commit db49df6
Showing 10 changed files with 107 additions and 63 deletions.
2 changes: 1 addition & 1 deletion configs/celeba.json
@@ -3,7 +3,7 @@
     "lr": 2e-5,
     "batch_size": 128,
     "grad_norm": 1.0,
-    "epochs": 100,
+    "epochs": 600,
     "warmup": 5000
   },
   "denoise": {
2 changes: 1 addition & 1 deletion configs/celebahq.json
@@ -3,7 +3,7 @@
     "lr": 2e-5,
     "batch_size": 64,
     "grad_norm": 1.0,
-    "epochs": 100,
+    "epochs": 600,
     "warmup": 5000
   },
   "denoise": {
2 changes: 1 addition & 1 deletion configs/cifar10.json
@@ -3,7 +3,7 @@
     "lr": 2e-4,
     "batch_size": 128,
     "grad_norm": 1.0,
-    "epochs": 100,
+    "epochs": 2160,
     "warmup": 5000
   },
   "denoise": {
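The epoch increases bring the configs in line with long DDPM training budgets. A back-of-the-envelope check for the CIFAR-10 config (assuming the standard 50,000-image training split and `drop_last=True`, as used by the trainer):

```python
# rough step count implied by the new cifar10.json values
steps_per_epoch = 50_000 // 128        # 390 batches per epoch with drop_last
total_steps = steps_per_epoch * 2160   # 842,400 optimizer steps
print(total_steps)                     # close to the ~800k steps reported for the original DDPM
```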
25 changes: 17 additions & 8 deletions ddim.py
@@ -8,6 +8,9 @@
 import ddpm_torch
 
 
+__all__ = ["get_selection_schedule", "DDIM"]
+
+
 # def get_selection_schedule(schedule, size, timesteps):
 #     """
 #     :param schedule: selection schedule
@@ -36,7 +39,7 @@ def get_selection_schedule(schedule, size, timesteps):
     if schedule == "linear":
         subsequence = torch.arange(0, timesteps, timesteps // size)
     else:
-        subsequence = torch.pow(torch.linspace(0, math.sqrt(timesteps * 0.8), size), 2).round().to(torch.int32)  # noqa
+        subsequence = torch.pow(torch.linspace(0, math.sqrt(timesteps * 0.8), size), 2).round().to(torch.int64)  # noqa
 
     return subsequence
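A note on the dtype change above: PyTorch's `gather` and tensor indexing require `int64` (`LongTensor`) indices, so the previous `int32` cast would raise at sampling time. A minimal standalone check of the fixed quadratic schedule:

```python
import math
import torch

# quadratic selection schedule, as in get_selection_schedule above
timesteps, size = 1000, 50
subsequence = torch.pow(
    torch.linspace(0, math.sqrt(timesteps * 0.8), size), 2
).round().to(torch.int64)

# gather requires int64 indices; with the old int32 cast this raised
# "gather(): Expected dtype int64 for index"
t = torch.zeros(4, dtype=torch.int64)
print(subsequence.gather(0, t))  # tensor([0, 0, 0, 0])
```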

@@ -48,8 +51,12 @@ def __init__(self, betas, model_mean_type, model_var_type, loss_type, eta, subsequence):
         self.subsequence = subsequence  # subsequence of the accelerated generation
 
         eta2 = eta ** 2
-        assert not (eta2 != 1. and model_var_type != "fixed-small"), \
-            'Cannot use DDIM (eta < 1) with var type other than "fixed-small"'
+        try:
+            assert not (eta2 != 1. and model_var_type != "fixed-small"), \
+                "Cannot use DDIM (eta < 1) with var type other than `fixed-small`"
+        except AssertionError:
+            # Automatically convert model_var_type to `fixed-small`
+            self.model_var_type = "fixed-small"
 
         self.alphas_bar = self.alphas_bar[subsequence]
         self.alphas_bar_prev = torch.cat([torch.ones(1, dtype=torch.float64), self.alphas_bar[:-1]], dim=0)
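The `try`/`except AssertionError` above turns what used to be a hard failure into a silent fallback to `fixed-small`. One subtlety: under `python -O`, `assert` statements are stripped, so the `except` branch would never run and an incompatible variance type would pass through unchanged. A hypothetical equivalent without exception-driven control flow, for comparison:

```python
import warnings

def resolve_var_type(eta, model_var_type):
    # hypothetical helper mirroring the fallback in the hunk above
    if eta ** 2 != 1. and model_var_type != "fixed-small":
        warnings.warn(
            "Cannot use DDIM (eta < 1) with var type other than `fixed-small`; "
            "falling back to `fixed-small`.")
        return "fixed-small"
    return model_var_type

print(resolve_var_type(0., "fixed-large"))  # fixed-small (with a warning)
```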
@@ -80,27 +87,29 @@ def __init__(self, betas, model_mean_type, model_var_type, loss_type, eta, subsequence):
         # for fixed model_var_type's
         self.fixed_model_var, self.fixed_model_logvar = {
             "fixed-large": (
-                self.betas, torch.log(
-                    torch.cat([self.posterior_var[[1]], self.betas[1:]]).clip(min=1e-20))),
+                self.betas, torch.log(torch.cat([self.posterior_var[[1]], self.betas[1:]]).clip(min=1e-20))),
             "fixed-small": (self.posterior_var, self.posterior_logvar_clipped)
         }[self.model_var_type]
 
         self.subsequence = torch.as_tensor(subsequence)
 
     @torch.inference_mode()
-    def p_sample(self, denoise_fn, shape, device=torch.device("cpu"), noise=None):
+    def p_sample(self, denoise_fn, shape, device=torch.device("cpu"), noise=None, seed=None):
         S = len(self.subsequence)
         B, *_ = shape
         subsequence = self.subsequence.to(device)
         _denoise_fn = lambda x, t: denoise_fn(x, subsequence.gather(0, t))
         t = torch.empty((B, ), dtype=torch.int64, device=device)
+        rng = None
+        if seed is not None:
+            rng = torch.Generator(device).manual_seed(seed)
         if noise is None:
-            x_t = torch.randn(shape, device=device)
+            x_t = torch.empty(shape, device=device).normal_(generator=rng)
         else:
             x_t = noise.to(device)
         for ti in range(S - 1, -1, -1):
             t.fill_(ti)
-            x_t = self.p_sample_step(_denoise_fn, x_t, t)
+            x_t = self.p_sample_step(_denoise_fn, x_t, t, generator=rng)
         return x_t
 
     @staticmethod
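A hedged usage sketch of the new `seed` parameter (`model` and `diffusion` stand in for a trained denoiser and a `DDIM` instance; shapes and device are illustrative): the seed pins both the initial noise and each step's noise through a dedicated `torch.Generator`, so repeated calls reproduce the same batch.

```python
import torch

# assumes `diffusion` is a DDIM instance and `model` a trained denoiser
samples_a = diffusion.p_sample(
    denoise_fn=model, shape=(16, 3, 32, 32),
    device=torch.device("cuda:0"), seed=1234)
samples_b = diffusion.p_sample(
    denoise_fn=model, shape=(16, 3, 32, 32),
    device=torch.device("cuda:0"), seed=1234)
assert torch.equal(samples_a, samples_b)  # same seed -> identical batch
```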
34 changes: 24 additions & 10 deletions ddpm_torch/metrics/__init__.py
@@ -1,6 +1,8 @@
+import math
+import torch
+from tqdm import trange
 from .fid_score import InceptionStatistics, get_precomputed, calc_fd
 from .precision_recall import ManifoldBuilder, Manifold, calc_pr
-import torch
 
 __all__ = [
     "InceptionStatistics",
@@ -19,21 +21,33 @@ def __init__(
         dataset,
         diffusion=None,
         eval_batch_size=256,
-        max_eval_count=10000,
+        eval_total_size=50000,
         device=torch.device("cpu")
     ):
         self.diffusion = diffusion
         # inception stats
         self.istats = InceptionStatistics(device=device)
         self.eval_batch_size = eval_batch_size
-        self.max_eval_count = max_eval_count
+        self.eval_total_size = eval_total_size
         self.device = device
         self.target_mean, self.target_var = get_precomputed(dataset)
 
-    def eval(self, sample_fn):
-        self.istats.reset()
-        for _ in range(0, self.max_eval_count + self.eval_batch_size, self.eval_batch_size):
-            x = sample_fn(self.eval_batch_size, diffusion=self.diffusion)
-            self.istats(x.to(self.device))
-        gen_mean, gen_var = self.istats.get_statistics()
-        return {"fid": calc_fd(gen_mean, gen_var, self.target_mean, self.target_var)}
+    def eval(self, sample_fn, is_leader=True):
+        if is_leader:
+            self.istats.reset()
+        fid = None
+        num_batches = math.ceil(self.eval_total_size / self.eval_batch_size)
+        with trange(num_batches, desc="Evaluating FID", disable=not is_leader) as t:
+            for i in t:
+                if i == len(t) - 1:
+                    batch_size = self.eval_total_size % self.eval_batch_size
+                else:
+                    batch_size = self.eval_batch_size
+                x = sample_fn(sample_size=batch_size, diffusion=self.diffusion)
+                if is_leader:
+                    self.istats(x.to(self.device))
+                    if i == len(t) - 1:
+                        gen_mean, gen_var = self.istats.get_statistics()
+                        fid = calc_fd(gen_mean, gen_var, self.target_mean, self.target_var)
+                        t.set_postfix({"fid": fid})
+        return {"fid": fid}
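The rewritten loop sizes the final batch by remainder so exactly `eval_total_size` images are generated. A quick check with the defaults added in this commit; note that if the total ever divides evenly by the batch size, the remainder is 0 and the last batch would be empty, which may be worth guarding against:

```python
import math

eval_total_size, eval_batch_size = 50000, 256
num_batches = math.ceil(eval_total_size / eval_batch_size)  # 196
last_batch = eval_total_size % eval_batch_size              # 80
assert (num_batches - 1) * eval_batch_size + last_batch == eval_total_size
```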
3 changes: 1 addition & 2 deletions ddpm_torch/metrics/precision_recall.py
@@ -22,9 +22,8 @@
 class VGGFeatureExtractor:
     WEIGHTS_URL = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt"
 
-    def __init__(self, device=torch.device("cpu")):
+    def __init__(self):
         self.model = self._load_model()
-        self.device = device
 
     def _load_model(self):
         model_path = os.path.join(get_dir(), os.path.basename(self.WEIGHTS_URL))
63 changes: 29 additions & 34 deletions ddpm_torch/utils/train.py
@@ -155,29 +155,29 @@ def step(self, x, global_steps=1):
             loss.div_(self.world_size)
         self.stats.update(x.shape[0], loss=loss.item() * x.shape[0])
 
-    def sample_fn(self, noise, diffusion=None):
+    def sample_fn(self, sample_size=None, noise=None, diffusion=None, sample_seed=None):
+        if noise is None:
+            shape = (sample_size // self.world_size, ) + self.shape
+        else:
+            shape = noise.shape
         if diffusion is None:
             diffusion = self.diffusion
-        shape = noise.shape
         with self.ema:
            sample = diffusion.p_sample(
                 denoise_fn=self.model, shape=shape,
-                device=self.device, noise=noise, seed=self.sample_seed)
+                device=self.device, noise=noise, seed=sample_seed)
+        if self.distributed:
+            # balance GPU memory usages within the same process group
+            sample_list = [torch.zeros(shape, device=self.device) for _ in range(self.world_size)]
+            dist.all_gather(sample_list, sample)
+            sample = torch.cat(sample_list, dim=0)
         assert sample.grad is None
         return sample
 
     def train(self, evaluator=None, chkpt_path=None, image_dir=None):
-
-        num_samples = self.num_save_images
-        nrow = math.floor(math.sqrt(num_samples))
-        if num_samples:
-            assert num_samples % self.world_size == 0, "Number of samples should be divisible by WORLD_SIZE!"
-            shape = (num_samples // self.world_size, ) + self.shape
-            # fix random number generator for sampling
-            rng = torch.Generator().manual_seed(self.sample_seed)
-            noise = torch.empty(shape).normal_(generator=rng)
-        else:
-            shape, noise = None, None
+        nrow = math.floor(math.sqrt(self.num_save_images))
+        if self.num_save_images:
+            assert self.num_save_images % self.world_size == 0, "Number of samples should be divisible by WORLD_SIZE!"
 
         if self.dry_run:
             self.start_epoch, self.epochs = 0, 1
@@ -199,28 +199,23 @@ def train(self, evaluator=None, chkpt_path=None, image_dir=None):
                         results.update(self.current_stats)
                     if self.dry_run and not global_steps % self.num_accum:
                         break
-                    if i == len(self.trainloader) - 1:
-                        self.model.eval()
-                        if evaluator is not None:
-                            eval_results = evaluator.eval(self.sample_fn)
-                        else:
-                            eval_results = dict()
-                        results.update(eval_results)
                     t.set_postfix(results)
 
-            if not (e + 1) % self.image_intv and num_samples and image_dir:
+            if not (e + 1) % self.image_intv and self.num_save_images and image_dir:
                 self.model.eval()
-                x = self.sample_fn(noise)
-                if self.distributed:
-                    # balance GPU memory usages within the same process group
-                    x_list = [torch.zeros(shape, device=self.device) for _ in range(self.world_size)]
-                    dist.all_gather(x_list, x)
-                    x = torch.cat(x_list, dim=0)
-                x = x.cpu()
+                x = self.sample_fn(sample_size=self.num_save_images, sample_seed=self.sample_seed).cpu()
                 if self.is_leader:
-                    save_image(x.cpu(), os.path.join(image_dir, f"{e + 1}.jpg"), nrow=nrow)
-            if not (e + 1) % self.chkpt_intv and chkpt_path and self.is_leader:
-                self.save_checkpoint(chkpt_path, epoch=e+1, **results)
+                    save_image(x, os.path.join(image_dir, f"{e + 1}.jpg"), nrow=nrow)
+
+            if not (e + 1) % self.chkpt_intv and chkpt_path:
+                self.model.eval()
+                if evaluator is not None:
+                    eval_results = evaluator.eval(self.sample_fn, is_leader=self.is_leader)
+                else:
+                    eval_results = dict()
+                results.update(eval_results)
+                if self.is_leader:
+                    self.save_checkpoint(chkpt_path, epoch=e+1, **results)
 
             if self.distributed:
                 dist.barrier()  # synchronize all processes here
 
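After this restructuring, the per-epoch FID pass is gone: sample grids are written every `image_intv` epochs and FID runs only on checkpoint epochs, with every rank sampling but only the leader aggregating statistics and saving. A hypothetical call site (`trainer` and `evaluator` are assumed instances of this Trainer and the Evaluator above):

```python
# hypothetical wiring, mirroring what train.py sets up below
trainer.train(
    evaluator=evaluator,                    # FID every chkpt_intv epochs
    chkpt_path="./chkpts/ddpm_cifar10.pt",  # illustrative path
    image_dir="./images/train",             # sample grid every image_intv epochs
)
```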
2 changes: 0 additions & 2 deletions eval.py
@@ -17,8 +17,6 @@
 
 parser.add_argument("--root", default="~/datasets", type=str)
 parser.add_argument("--dataset", choices=["mnist", "cifar10", "celeba", "celebahq"], default="cifar10")
-parser.add_argument("--model-device", default=0, type=int)
-parser.add_argument("--eval-device", default=0, type=int)
 parser.add_argument("--eval-batch-size", default=512, type=int)
 parser.add_argument("--eval-total-size", default=50000, type=int)
 parser.add_argument("--num-workers", default=4, type=int)
2 changes: 1 addition & 1 deletion generate.py
@@ -131,7 +131,7 @@ def main():
     parser.add_argument("--use-ddim", action="store_true")
     parser.add_argument("--eta", default=0., type=float)
     parser.add_argument("--skip-schedule", default="linear", type=str)
-    parser.add_argument("--subseq-size", default=10, type=int)
+    parser.add_argument("--subseq-size", default=50, type=int)
     parser.add_argument("--suffix", default="", type=str)
     parser.add_argument("--max-workers", default=8, type=int)
     parser.add_argument("--num-gpus", default=1, type=int)
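For context on the new default: DDIM runs one denoiser call per subsequence element, so `--subseq-size 50` against a full diffusion of, say, 1000 timesteps (an assumption; the configs' timestep count is not shown on this page) costs 50 network evaluations per sample:

```python
# rough cost arithmetic; timesteps=1000 is an assumed value
timesteps, subseq_size = 1000, 50
print(f"{subseq_size} denoiser calls per sample, "
      f"{timesteps // subseq_size}x fewer than full DDPM")
```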
35 changes: 32 additions & 3 deletions train.py
@@ -4,6 +4,7 @@
 import tempfile
 from datetime import datetime
 from functools import partial
+from ddim import *
 from ddpm_torch import *
 from torch.optim import Adam, lr_scheduler
 import torch.distributed as dist
@@ -139,6 +140,11 @@ def logger(msg, **kwargs):
         root=root, drop_last=True, pin_memory=True, num_workers=num_workers, distributed=distributed
     )  # drop_last to have a static input shape; num_workers > 0 to enable asynchronous data loading
 
+    if args.dry_run:
+        logger("This is a dry run.")
+        args.chkpt_intv = 1
+        args.image_intv = 1
+
     chkpt_dir = args.chkpt_dir
     chkpt_path = os.path.join(chkpt_dir, args.chkpt_name or f"ddpm_{dataset}.pt")
     chkpt_intv = args.chkpt_intv
@@ -194,7 +200,25 @@ def logger(msg, **kwargs):
         distributed=distributed,
         dry_run=args.dry_run
     )
-    evaluator = Evaluator(dataset=dataset, device=eval_device) if args.eval else None
+
+    if args.use_ddim:
+        subsequence = get_selection_schedule(
+            args.skip_schedule, size=args.subseq_size, timesteps=diffusion_configs.timesteps)
+        diffusion_eval = DDIM.from_ddpm(diffusion, eta=0., subsequence=subsequence)
+    else:
+        diffusion_eval = diffusion
+
+    if args.eval:
+        evaluator = Evaluator(
+            dataset=dataset,
+            diffusion=diffusion_eval,
+            eval_batch_size=args.eval_batch_size,
+            eval_total_size=args.eval_total_size,
+            device=eval_device
+        )
+    else:
+        evaluator = None
+
     # in the case of distributed training, resume should always be turned on
     resume = args.resume or distributed
     if resume:
@@ -240,17 +264,22 @@ def main():
     parser.add_argument("--train-device", default="cuda:0", type=str)
     parser.add_argument("--eval-device", default="cuda:0", type=str)
     parser.add_argument("--image-dir", default="./images/train", type=str)
-    parser.add_argument("--image-intv", default=1, type=int)
+    parser.add_argument("--image-intv", default=10, type=int)
     parser.add_argument("--num-save-images", default=64, type=int, help="number of images to generate & save")
     parser.add_argument("--config-dir", default="./configs", type=str)
     parser.add_argument("--chkpt-dir", default="./chkpts", type=str)
     parser.add_argument("--chkpt-name", default="", type=str)
-    parser.add_argument("--chkpt-intv", default=5, type=int, help="frequency of saving a checkpoint")
+    parser.add_argument("--chkpt-intv", default=100, type=int, help="frequency of saving a checkpoint")
     parser.add_argument("--seed", default=1234, type=int, help="random seed")
     parser.add_argument("--resume", action="store_true", help="to resume training from a checkpoint")
     parser.add_argument("--chkpt-path", default="", type=str, help="checkpoint path used to resume training")
     parser.add_argument("--eval", action="store_true", help="whether to evaluate fid during training")
+    parser.add_argument("--eval-total-size", default=50000, type=int)
+    parser.add_argument("--eval-batch-size", default=256, type=int)
     parser.add_argument("--use-ema", action="store_true", help="whether to use exponential moving average")
+    parser.add_argument("--use-ddim", action="store_true", help="whether to use DDIM sampler for evaluation")
+    parser.add_argument("--skip-schedule", choices=["linear", "quadratic"], default="linear", type=str)
+    parser.add_argument("--subseq-size", default=50, type=int)
     parser.add_argument("--ema-decay", default=0.9999, type=float, help="decay factor of ema")
     parser.add_argument("--distributed", action="store_true", help="whether to use distributed training")
     parser.add_argument("--rigid-launch", action="store_true", help="whether to use torch multiprocessing spawn")
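Putting the pieces together, a hedged launch example using the fixed `--eval` path plus the new DDIM evaluation flags (dataset and values are illustrative; a plain shell invocation works just as well):

```python
import subprocess

# illustrative launch of train.py with the options added in this commit
subprocess.run([
    "python", "train.py",
    "--dataset", "cifar10",
    "--use-ema",
    "--eval",                       # the flag this commit fixes
    "--use-ddim",                   # evaluate with the DDIM sampler
    "--skip-schedule", "quadratic",
    "--subseq-size", "50",
    "--eval-batch-size", "256",
    "--eval-total-size", "50000",
], check=True)
```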
