
HOWTOs


How to train StyleGAN2

  1. Prepare the training dataset: FFHQ. More details are in DatasetPreparation.md.

    1. Download the FFHQ dataset. We recommend downloading the tfrecords files from NVlabs/ffhq-dataset.

    2. Extract tfrecords to images or LMDBs (TensorFlow is required to read tfrecords):

      python scripts/data_preparation/extract_images_from_tfrecords.py

  2. Modify the config file options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ.yml accordingly.

  3. Train with distributed training. More training commands are in TrainTest.md.

    python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml --launcher pytorch
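If you launch training from a Python script instead of a shell, the command above can be composed as an argument list, which avoids shell-quoting issues. This is a minimal sketch. The helper names build_launch_cmd, total_batch_size and launch are hypothetical and not part of BasicSR. It also assumes your config sets the per-process batch size with batch_size_per_gpu, as BasicSR training configs typically do, so check the datasets/train section of your yml. Because --nproc_per_node=8 starts one process per GPU and each process loads its own mini-batch, one optimizer step sees batch_size_per_gpu × 8 samples.

```python
import shlex
import subprocess


def build_launch_cmd(opt_path, nproc_per_node=8, master_port=4321, extra_args=()):
    """Return the distributed-training command above as an argv list."""
    return [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}",
        f"--master_port={master_port}",
        "basicsr/train.py",
        "-opt", opt_path,
        "--launcher", "pytorch",
        *extra_args,  # e.g. "--auto_resume"
    ]


def total_batch_size(batch_size_per_gpu, nproc_per_node=8):
    # One process per GPU, each drawing its own mini-batch, so the global
    # batch per optimizer step is the per-GPU batch times the process count.
    return batch_size_per_gpu * nproc_per_node


def launch(opt_path, **kwargs):
    cmd = build_launch_cmd(opt_path, **kwargs)
    print("Running:", shlex.join(cmd))
    # Run from the BasicSR repository root; raises CalledProcessError on failure.
    subprocess.run(cmd, check=True)
```

For example, launch("options/train/StyleGAN/train_StyleGAN2_256_Cmul2_FFHQ_800k.yml") reproduces the command above. With batch_size_per_gpu set to 4, total_batch_size(4) gives a global batch of 32.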

How to run StyleGAN2 inference

  1. Download pre-trained models from ModelZoo (Google Drive, Baidu Netdisk) to the experiments/pretrained_models folder.

  2. Test.

    python inference/inference_stylegan2.py

  3. The results are in the samples folder.

How to run DFDNet inference

  1. Install dlib, which DFDNet uses for face recognition and landmark detection. See the dlib installation reference.

    1. Clone the dlib repo: git clone git@github.com:davisking/dlib.git
    2. cd dlib
    3. Install: python setup.py install
  2. Download the dlib pretrained models from ModelZoo (Google Drive, Baidu Netdisk) to the experiments/pretrained_models/dlib folder.
    You can download them by running the following command, or download them manually.

    python scripts/download_pretrained_models.py dlib

  3. Download the pretrained DFDNet models, dictionary and face template from ModelZoo (Google Drive, Baidu Netdisk) to the experiments/pretrained_models/DFDNet folder.
    You can download them by running the following command, or download them manually.

    python scripts/download_pretrained_models.py DFDNet

  4. Prepare the test dataset in the datasets folder. For example, we put images in the datasets/TestWhole folder.

  5. Test.

    python inference/inference_dfdnet.py --upscale_factor=2 --test_path datasets/TestWhole

  6. The results are in the results/DFDNet folder.
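Step 4 only requires that the test images sit in one folder. This is a minimal sketch for checking which files in datasets/TestWhole look like images before running step 5. The list_test_images helper and its extension set are hypothetical, and inference_dfdnet.py may filter files differently.

```python
from pathlib import Path

# Hypothetical extension set; adjust it to the formats your test images use.
IMG_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp"}


def list_test_images(test_path="datasets/TestWhole"):
    """Return sorted image files directly inside test_path (non-recursive)."""
    folder = Path(test_path)
    if not folder.is_dir():
        raise FileNotFoundError(f"Test folder not found: {folder}")
    return sorted(
        p for p in folder.iterdir()
        # Skip sub-folders and files whose extension is not in the set above.
        if p.is_file() and p.suffix.lower() in IMG_EXTENSIONS
    )
```

Calling list_test_images() from the repository root lists the files that step 5 is expected to process. An empty list usually means the images are in a different folder or use an unexpected extension.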

How to train SwinIR (SR)

We take classical SR X4 on DIV2K as an example.

  1. Prepare the training dataset: DIV2K. More details are in DatasetPreparation.md.

  2. Prepare the validation dataset: Set5. You can download it by following this guidance.

  3. Modify the config file in options/train/SwinIR/train_SwinIR_SRx4_scratch.yml accordingly.

  4. Train with distributed training. More training commands are in TrainTest.md.

    python -m torch.distributed.launch --nproc_per_node=8 --master_port=4331 basicsr/train.py -opt options/train/SwinIR/train_SwinIR_SRx4_scratch.yml --launcher pytorch --auto_resume
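The Set5 ground truth used for evaluation lives in a GTmod12 folder. By the usual SR convention, this means each GT image is cropped so that its height and width are multiples of 12, which keeps the size divisible by scales 2, 3 and 4 alike. If you prepare your own validation GT, the following is a minimal sketch of such a mod crop on an image stored as nested lists. The mod_crop and lr_size names and the bottom/right cropping are assumptions for illustration, not necessarily how BasicSR's preparation scripts do it.

```python
def mod_crop(img, mod=12):
    """Crop an H x W (x C) image, stored as nested lists, so that H and W are
    multiples of `mod`. Extra rows/columns are dropped from the bottom and right."""
    h, w = len(img), len(img[0])
    h_keep, w_keep = h - h % mod, w - w % mod
    if h_keep == 0 or w_keep == 0:
        raise ValueError(f"image {h}x{w} is smaller than mod={mod}")
    return [row[:w_keep] for row in img[:h_keep]]


def lr_size(gt_h, gt_w, scale):
    """Size of the bicubic-downsampled LR image for a mod-cropped GT."""
    if gt_h % scale or gt_w % scale:
        raise ValueError("GT size must be divisible by scale; mod-crop it first")
    return gt_h // scale, gt_w // scale
```

With NumPy arrays, the same crop is simply img[:h_keep, :w_keep]. Because 12 is divisible by 2, 3 and 4, a single GTmod12 set pairs cleanly with X2, X3 and X4 LR images.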

Note that:

  1. Unlike the original setting in the paper, where the X4 model is finetuned from the X2 model, we train it directly from scratch.
  2. We also use EMA (Exponential Moving Average). All model training in BasicSR supports EMA.
  3. At 250K iterations, the X4 model achieves performance comparable to the official model.

| ClassicalSR DIV2K X4 | PSNR (RGB) | PSNR (Y) | SSIM (RGB) | SSIM (Y) |
| --- | --- | --- | --- | --- |
| Official | 30.803 | 32.728 | 0.8738 | 0.9028 |
| Reproduce | 30.832 | 32.756 | 0.8739 | 0.9025 |
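The EMA mentioned in the notes keeps a second copy of the network weights that trails the trained weights. That copy is typically the one used for validation and testing. The minimal sketch below illustrates the update rule ema = decay * ema + (1 - decay) * param on plain floats. It is not BasicSR's implementation, and the default decay of 0.999 is an assumed example value. If your config has an ema_decay entry, that value plays this role.

```python
def ema_update(ema_params, params, decay=0.999):
    """Update ema_params in place: ema <- decay * ema + (1 - decay) * param.

    Parameters are plain floats keyed by name here; a real model would hold
    tensors, but the per-parameter arithmetic is the same.
    """
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


# Toy run: the EMA copy starts as a clone of the weights ...
params = {"w": 0.0}
ema = dict(params)
for step in range(3):
    params["w"] += 1.0                  # stand-in for an optimizer step
    ema_update(ema, params, decay=0.5)  # ... and trails them after every step
print(ema["w"])  # → 2.125
```

With a decay close to 1, the EMA weights change slowly and average over many recent steps, which smooths out iteration-to-iteration noise in the trained weights.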

How to run SwinIR (SR) inference

  1. Download pre-trained models from the official SwinIR repo to the experiments/pretrained_models/SwinIR folder.

  2. Run inference.

    python inference/inference_swinir.py --input datasets/Set5/LRbicx4 --patch_size 48 --model_path experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth --output results/SwinIR_SRX4_DIV2K/Set5

  3. The results are in the results/SwinIR_SRX4_DIV2K/Set5 folder.

  4. Optionally, calculate the PSNR/SSIM values:

    python scripts/metrics/calculate_psnr_ssim.py --gt datasets/Set5/GTmod12/ --restored results/SwinIR_SRX4_DIV2K/Set5 --crop_border 4

    Or evaluate on the Y channel by adding the --test_y_channel argument:

    python scripts/metrics/calculate_psnr_ssim.py --gt datasets/Set5/GTmod12/ --restored results/SwinIR_SRX4_DIV2K/Set5 --crop_border 4 --test_y_channel
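The sketch below illustrates what --crop_border and --test_y_channel control when computing PSNR. It is a minimal pure-Python sketch, not a copy of scripts/metrics/calculate_psnr_ssim.py. It assumes 8-bit RGB images given as nested lists and drops crop_border pixels from every edge, where 4 matches the X4 scale used here. For the Y channel, it uses the ITU-R BT.601 luma formula from MATLAB's rgb2ycbcr, which is the common convention in SR papers. The official script may differ in details such as channel order and rounding, and SSIM is omitted for brevity.

```python
import math


def rgb_to_y(pixel):
    """ITU-R BT.601 luma as in MATLAB's rgb2ycbcr, for 8-bit RGB in [0, 255]."""
    r, g, b = pixel
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr(img1, img2, crop_border=0, test_y_channel=False):
    """PSNR in dB between two H x W x 3 images given as nested lists of 8-bit RGB."""
    h, w = len(img1), len(img1[0])
    if (h, w) != (len(img2), len(img2[0])):
        raise ValueError("images must have the same size")
    if 2 * crop_border >= min(h, w):
        raise ValueError("crop_border too large for image size")
    sq_err, n = 0.0, 0
    # Ignore crop_border pixels on every edge.
    for i in range(crop_border, h - crop_border):
        for j in range(crop_border, w - crop_border):
            if test_y_channel:
                diffs = [rgb_to_y(img1[i][j]) - rgb_to_y(img2[i][j])]
            else:
                diffs = [a - b for a, b in zip(img1[i][j], img2[i][j])]
            for d in diffs:
                sq_err += d * d
                n += 1
    mse = sq_err / n
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(255.0 ** 2 / mse)
```

Cropping the border matters because pixels near the image edge are harder to reconstruct and are conventionally excluded. A difference confined to those pixels therefore does not lower the reported PSNR.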