update README
tqch committed Sep 18, 2022
1 parent a349887 commit 7844df6
Showing 9 changed files with 45 additions and 26 deletions.
71 changes: 45 additions & 26 deletions README.md
@@ -1,8 +1,12 @@
# PyTorch Implementation of Denoising Diffusion Probabilistic Models [[paper]](https://arxiv.org/abs/2006.11239) [[official repo]](https://github.com/hojonathanho/diffusion)

## Code usage

### Toy data
```shell

<details>
<summary>Expand</summary>
<pre><code>
usage: train_toy.py [-h] [--dataset {gaussian8,gaussian25,swissroll}]
[--size SIZE] [--root ROOT] [--epochs EPOCHS] [--lr LR]
[--beta1 BETA1] [--beta2 BETA2] [--lr-warmup LR_WARMUP]
@@ -16,7 +20,6 @@ usage: train_toy.py [-h] [--dataset {gaussian8,gaussian25,swissroll}]
[--eval-intv EVAL_INTV] [--seed SEED] [--resume]
[--gpu GPU] [--mid-features MID_FEATURES]
[--num-temporal-layers NUM_TEMPORAL_LAYERS]

optional arguments:
-h, --help show this help message and exit
--dataset {gaussian8,gaussian25,swissroll}
@@ -47,28 +50,28 @@ optional arguments:
--gpu GPU
--mid-features MID_FEATURES
--num-temporal-layers NUM_TEMPORAL_LAYERS
```
</code></pre>
</details>

### Real-world data

```shell
usage: train.py [-h] [--model {unet}] [--dataset {mnist,cifar10,celeba}]
[--root ROOT] [--epochs EPOCHS] [--lr LR] [--beta1 BETA1]
[--beta2 BETA2] [--batch-size BATCH_SIZE]
[--timesteps TIMESTEPS]
<details><summary>Expand</summary>
<pre><code>
usage: train.py [-h] [--dataset {mnist,cifar10,celeba}] [--root ROOT]
[--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]
[--batch-size BATCH_SIZE] [--timesteps TIMESTEPS]
[--beta-schedule {quad,linear,warmup10,warmup50,jsd}]
[--beta-start BETA_START] [--beta-end BETA_END]
[--model-mean-type {mean,x_0,eps}]
[--model-var-type {learned,fixed-small,fixed-large}]
[--loss-type {kl,mse}] [--task {generation}]
[--loss-type {kl,mse}] [--num-workers NUM_WORKERS]
[--train-device TRAIN_DEVICE] [--eval-device EVAL_DEVICE]
[--image-dir IMAGE_DIR] [--num-save-images NUM_SAVE_IMAGES]
[--config-dir CONFIG_DIR] [--chkpt-dir CHKPT_DIR]
[--chkpt-intv CHKPT_INTV] [--log-dir LOG_DIR] [--seed SEED]
[--resume] [--eval] [--use-ema] [--ema-decay EMA_DECAY]

[--chkpt-intv CHKPT_INTV] [--seed SEED] [--resume] [--eval]
[--use-ema] [--ema-decay EMA_DECAY] [--distributed]
optional arguments:
-h, --help show this help message and exit
--model {unet} backbone decoder
--dataset {mnist,cifar10,celeba}
--root ROOT root directory of datasets
--epochs EPOCHS total number of training epochs
@@ -84,7 +87,8 @@ optional arguments:
--model-mean-type {mean,x_0,eps}
--model-var-type {learned,fixed-small,fixed-large}
--loss-type {kl,mse}
--task {generation}
--num-workers NUM_WORKERS
number of workers for data loading
--train-device TRAIN_DEVICE
--eval-device EVAL_DEVICE
--image-dir IMAGE_DIR
@@ -94,22 +98,26 @@ optional arguments:
--chkpt-dir CHKPT_DIR
--chkpt-intv CHKPT_INTV
frequency of saving a checkpoint
--log-dir LOG_DIR
--seed SEED random seed
--resume to resume from a checkpoint
--eval whether to evaluate fid during training
--use-ema whether to use exponential moving average
--ema-decay EMA_DECAY
decay factor of ema
```
--distributed whether to use distributed training
</code></pre>
</details>

### Examples
```shell
# train a 25-Gaussian toy model on cuda:0 a total of 100 epochs
# train a 25-Gaussian toy model on cuda:0 for a total of 100 epochs
python train_toy.py --dataset gaussian25 --gpu 0 --epochs 100

# train a cifar10 model on cuda:0 for a total of 50 epochs
# train a cifar10 model with single gpu for a total of 50 epochs
python train.py --dataset cifar10 --train-device cuda:0 --epochs 50

# train a celeba model with 2 gpus and an effective batch-size of 64 x 2 = 128
export CUDA_VISIBLE_DEVICES=0,1&&torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --use-ema --distributed
```

## Experiment results
@@ -127,21 +135,33 @@ python train.py --dataset cifar10 --train-device cuda:0 --epochs 50

### Real-world data

*Table of evaluation metrics*

|Dataset|FID (↓)|Precision (↑)|Recall (↑)|Training steps|
|:---:|:---:|:---:|:---:|:---:|
|CIFAR-10|11.11|0.738|0.421|46.8k|
|\|__|6.45|0.727|0.480|93.6k|
|\|__|4.99|0.727|0.503|140.4k|
|\|__|4.48|0.730|0.517|187.2k|
|\|__|4.07|0.731|**0.524**|234.0k|
|\|__|**4.01**|**0.733**|0.520|280.8k|
|CelebA|4.45|0.778|0.478|237.3k|
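
The step counts in the table can be related to the epoch counts quoted in the sample headings of this README (720 epochs for CIFAR-10, 150 for CelebA). The sketch below does that arithmetic. It assumes that the last CIFAR-10 row and the CelebA row correspond to those epoch counts, and that the dataset sizes are 50,000 (CIFAR-10 training set) and 202,599 (all CelebA images); neither assumption is stated in the README, so the implied batch size is only an estimate.

```python
# Steps and epochs are read from the table and the sample headings;
# pairing them up is an assumption, not something the README states.
RUNS = {
    "CIFAR-10": {"steps": 280_800, "epochs": 720},
    "CelebA": {"steps": 237_300, "epochs": 150},
}

# Assumed dataset sizes (not given in the README).
DATASET_SIZES = {"CIFAR-10": 50_000, "CelebA": 202_599}


def steps_per_epoch(steps: int, epochs: int) -> float:
    """Average number of optimizer steps taken in one epoch."""
    return steps / epochs


def implied_batch_size(num_images: int, per_epoch: float) -> float:
    """Batch size that would produce `per_epoch` steps over `num_images` images."""
    return num_images / per_epoch


for name, run in RUNS.items():
    per_epoch = steps_per_epoch(run["steps"], run["epochs"])
    batch = implied_batch_size(DATASET_SIZES[name], per_epoch)
    print(f"{name}: {per_epoch:.0f} steps/epoch, implied batch size ~{batch:.1f}")
```

Under these assumptions both runs come out close to a batch size of 128, which agrees with the effective batch size of 64 x 2 = 128 quoted for the CelebA example.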

#### CIFAR-10

##### Training samples (100 epochs)
<p align="center"> <img alt="cifar10_train_100" src="./assets/cifar10_train_100.gif" /> </p>
##### Training samples (720 epochs)
<p align="center"> <img alt="cifar10_train_100" src="./assets/cifar10_train_720.webp" /> </p>

##### Denoising process
<p align="center"> <img alt="cifar10_denoise_100" src="./assets/cifar10_denoise_100.gif" /> </p>
<p align="center"> <img alt="cifar10_denoise_100" src="./assets/cifar10_denoise_100.webp" /></p>

#### Celeb-A
#### CelebA

##### Training samples (100 epochs)
<p align="center"> <img alt="celeba_train_100" src="./assets/celeba_train_100.gif" /> </p>
##### Training samples (150 epochs)
<p align="center"> <img alt="celeba_train_100" src="./assets/celeba_train_150.webp" /> </p>

##### Denoising process
<p align="center"> <img alt="celeba_denoise_100" src="./assets/celeba_denoise_100.gif" /> </p>
<p align="center"> <img alt="celeba_denoise_100" src="./assets/celeba_denoise_100.webp" /> </p>

## Reference formulae

@@ -159,4 +179,3 @@ $$ x\_{t-1} \mid x\_t, x\_0 \sim \text{N}\left(\frac{1}{\sqrt{\bar{\alpha}\_t}}\
where

$$\sigma\_t^2 = \frac{\beta\_t(1-\bar{\alpha}\_{t-1})}{1-\bar{\alpha}\_t}$$
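
As a numerical illustration of the variance formula, the following is a minimal sketch in plain Python, not the repository's implementation. It assumes the usual DDPM definitions $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s \le t} \alpha_s$ with $\bar{\alpha}_0 = 1$, and a linear schedule from 1e-4 to 0.02 over 1000 steps (the values used in the DDPM paper; the defaults of `--beta-start`, `--beta-end` and `--timesteps` are not shown here). The function names are hypothetical.

```python
def linear_betas(beta_start: float, beta_end: float, timesteps: int) -> list:
    """Evenly spaced beta_1..beta_T (assumed linear schedule)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def posterior_variances(betas: list) -> list:
    """sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) for t = 1..T."""
    variances = []
    alpha_bar_prev = 1.0  # alpha_bar_0 = 1 by convention
    for beta in betas:
        alpha_bar = alpha_bar_prev * (1.0 - beta)  # running product of alpha_t
        variances.append(beta * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar))
        alpha_bar_prev = alpha_bar
    return variances


betas = linear_betas(1e-4, 0.02, 1000)  # assumed schedule parameters
sigma2 = posterior_variances(betas)
print(f"sigma_1^2 = {sigma2[0]:.3e}, sigma_T^2 = {sigma2[-1]:.3e}, beta_T = {betas[-1]:.3e}")
```

Because $\bar{\alpha}_t \le \bar{\alpha}_{t-1}$, the posterior variance never exceeds $\beta_t$, and it is exactly zero at $t = 1$. In the DDPM paper these two quantities are the two fixed choices for the reverse-process variance, which is presumably what the `fixed-small` and `fixed-large` options of `--model-var-type` select between.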

Binary file removed assets/celeba_denoise_100.gif
Binary file not shown.
Binary file added assets/celeba_denoise_100.webp
Binary file not shown.
Binary file removed assets/celeba_train_100.gif
Binary file not shown.
Binary file added assets/celeba_train_150.webp
Binary file not shown.
Binary file removed assets/cifar10_denoise_100.gif
Binary file not shown.
Binary file added assets/cifar10_denoise_100.webp
Binary file not shown.
Binary file removed assets/cifar10_train_100.gif
Binary file not shown.
Binary file added assets/cifar10_train_720.webp
Binary file not shown.
