diff --git a/README.md b/README.md
index 1852509..02ddb4e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
 # GAN Compression
 ### [project](https://hanlab.mit.edu/projects/gancompression/) | [paper](https://arxiv.org/abs/2003.08936) | [demo](https://www.youtube.com/playlist?list=PL80kAHvQbh-r5R8UmXhQK1ndqRvPNw_ex)
+**[NEW!]** The simplified pipeline of GAN Compression is released! Check the [tutorial](docs/simplified_pipeline.md) for details.
+
 **[NEW!]** GauGAN training code and tutorial is released! Check the [tutorial](docs/training_tutorial.md) to compress GauGAN.
 
 **[NEW!]** Correct metric naming and update the evaluation protocol. Support MACs budget for searching.
@@ -9,8 +11,6 @@
 **[NEW!]** The [tutorial](docs/training_tutorial.md) of compression is released! Check the [overview](docs/overview.md) for better understanding our codebase.
 
-**[NEW!]** The PyTorch implementation of a general conditional GAN Compression framework is released.
-
 ![teaser](imgs/teaser.png)*We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.*
 
 GAN Compression: Efficient Architectures for Interactive Conditional GANs
@@ -189,7 +189,7 @@ PyTorch Colab notebook: [CycleGAN](https://colab.research.google.com/github/mit-
 
 ### Cityscapes Dataset
 
-For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script `datasets/prepare_cityscapes_dataset.py` to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:
+We cannot provide the Cityscapes dataset due to license issues. Please download the dataset from https://cityscapes-dataset.com and use the script [prepare_cityscapes_dataset.py](datasets/prepare_cityscapes_dataset.py) to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:
 
 ```shell
 python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
@@ -210,7 +210,7 @@ Please refer to our training [tutorial](docs/training_tutorial.md) on how to tra
 
 ### FID Computation
 
-To compute the FID score, you need to get some statistical information from the groud-truth images of your dataset. We provide a script `get_real_stat.py` to extract statistical information. For example, for the edges2shoes dataset, you could run the following command:
+To compute the FID score, you need to get some statistical information from the ground-truth images of your dataset. We provide a script [get_real_stat.py](./get_real_stat.py) to extract statistical information.
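+The saved statistics are just the mean and covariance of InceptionV3 features over the real images; FID is then the Fréchet distance between the real and generated feature distributions. A minimal sketch of the math (illustrative only — the feature-extraction step is omitted, and the actual implementation lives in [get_real_stat.py](./get_real_stat.py) and the `metric` package):
+
+```python
+import numpy as np
+from scipy import linalg
+
+def compute_stats(features):
+    # features: (N, 2048) array of InceptionV3 activations of the real images
+    mu = np.mean(features, axis=0)
+    sigma = np.cov(features, rowvar=False)
+    return mu, sigma
+
+def frechet_distance(mu1, sigma1, mu2, sigma2):
+    # FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * (S1 S2)^(1/2))
+    covmean = linalg.sqrtm(sigma1.dot(sigma2)).real
+    diff = mu1 - mu2
+    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean))
+```
+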
For example, for the edges2shoes dataset, you could run the following command: ```shell python get_real_stat.py \ diff --git a/configs/resnet_configs.py b/configs/resnet_configs.py index d5d463a..8c91e2b 100644 --- a/configs/resnet_configs.py +++ b/configs/resnet_configs.py @@ -71,6 +71,22 @@ def get_configs(config_name): return ResnetConfigs(n_channels=[[32, 24, 16], [32, 24, 16], [32, 24, 16], [32, 24, 16], [32, 24, 16], [32, 24, 16], [32, 24, 16], [32, 24, 16]]) + elif config_name == 'channels-64-cycleGAN-stage1': + return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 32], + [64, 48, 32], [64, 48, 32], [64, 48, 32], + [64, 48, 32], [64, 48, 32]]) + elif config_name == 'channels-64-cycleGAN': + return ResnetConfigs(n_channels=[[64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16], + [64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16], + [64, 48, 32, 24, 16], [64, 48, 32, 24, 16]]) + elif config_name == 'channels-64-pix2pix-stage1': + return ResnetConfigs(n_channels=[[64, 48], [64, 48], [64, 48], + [64, 48], [64, 48], [64, 48], + [64, 48, 32], [64, 48, 32]]) + elif config_name == 'channels-64-pix2pix': + return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 40, 32], + [64, 48, 40, 32], [64, 48, 40, 32], [64, 48, 40, 32], + [64, 48, 32, 24, 16], [64, 48, 32, 24, 16]]) elif config_name == 'debug': return ResnetConfigs(n_channels=[[48, 32], [48, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], diff --git a/configs/spade_configs.py b/configs/spade_configs.py index 537fed0..988ff30 100644 --- a/configs/spade_configs.py +++ b/configs/spade_configs.py @@ -29,7 +29,7 @@ def smallest(self): ret['channels'].append(min(n_channel)) return ret - def all_configs(self, split=1, remainder=0): + def all_configs(self): def yield_channels(i): if i == len(self.n_channels): @@ -39,9 +39,8 @@ def yield_channels(i): for after_channels in yield_channels(i + 1): yield [n] + after_channels - for i, channels in enumerate(yield_channels(0)): - if i % split == remainder: - yield {'channels': channels} + for channels in yield_channels(0): + yield {'channels': channels} def __call__(self, name): assert name in ('largest', 'smallest') @@ -64,74 +63,6 @@ def get_configs(config_name): return SPADEConfigs(n_channels=[[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]]) - elif config_name == 'channels-48-part1': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [48, 24], [48, 24], [48, 24]]) - elif config_name == 'channels-48-part2': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [48, 24], [48, 24], [40, 32]]) - elif config_name == 'channels-48-part3': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [48, 24], [40, 32], [48, 24]]) - elif config_name == 'channels-48-part4': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [48, 24], [40, 32], [40, 32]]) - elif config_name == 'channels-48-part5': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [40, 32], [48, 24], [48, 24]]) - elif config_name == 'channels-48-part6': - return SPADEConfigs(n_channels=[[48, 40, 32], - [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32], - [40, 32], [48, 24], [40, 32]]) - elif config_name == 'channels-48-part7': - 
return SPADEConfigs(n_channels=[[48, 40, 32],
-                                        [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
-                                        [40, 32], [40, 32], [48, 24]])
-    elif config_name == 'channels-48-part8':
-        return SPADEConfigs(n_channels=[[48, 40, 32],
-                                        [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
-                                        [40, 32], [40, 32], [40, 32]])
-    elif config_name == 'channels2-48':
-        return SPADEConfigs(n_channels=[[48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part1':
-        return SPADEConfigs(n_channels=[[48],
-                                        [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part2':
-        return SPADEConfigs(n_channels=[[40],
-                                        [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part3':
-        return SPADEConfigs(n_channels=[[32],
-                                        [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part4':
-        return SPADEConfigs(n_channels=[[24],
-                                        [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part5':
-        return SPADEConfigs(n_channels=[[48],
-                                        [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part6':
-        return SPADEConfigs(n_channels=[[40],
-                                        [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part7':
-        return SPADEConfigs(n_channels=[[32],
-                                        [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
-    elif config_name == 'channels2-48-part8':
-        return SPADEConfigs(n_channels=[[24],
-                                        [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
-                                        [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
     elif config_name == 'debug':
         return SPADEConfigs(n_channels=[[48],
                                         [48], [48], [48], [48],
diff --git a/datasets/coco_generate_instance_map.py b/datasets/coco_generate_instance_map.py
new file mode 100644
index 0000000..606d9ab
--- /dev/null
+++ b/datasets/coco_generate_instance_map.py
@@ -0,0 +1,56 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import argparse
+import os
+
+import numpy as np
+import skimage.io as io
+import tqdm
+from pycocotools.coco import COCO
+from skimage.draw import polygon
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--annotation_file', type=str, default="./annotations/instances_train2017.json",
+                    help="Path to the annotation file. It can be downloaded at http://images.cocodataset.org/annotations/annotations_trainval2017.zip. Should be either instances_train2017.json or instances_val2017.json")
+parser.add_argument('--input_label_dir', type=str, default="./train_label/",
+                    help="Path to the directory containing label maps. It can be downloaded at http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip")
+parser.add_argument('--output_instance_dir', type=str, default="./train_inst/",
+                    help="Path to the output directory of instance maps")
+opt = parser.parse_args()
+os.makedirs(opt.output_instance_dir, exist_ok=True)
+
+print("annotation file at {}".format(opt.annotation_file))
+print("input label maps at {}".format(opt.input_label_dir))
+print("output dir at {}".format(opt.output_instance_dir))
+
+# initialize COCO api for instance annotations
+coco = COCO(opt.annotation_file)
+
+# display COCO categories and supercategories
+cats = coco.loadCats(coco.getCatIds())
+imgIds = coco.getImgIds(catIds=coco.getCatIds(cats))
+for ix, id in enumerate(tqdm.tqdm(imgIds)):
+    img_dict = coco.loadImgs(id)[0]
+    filename = img_dict["file_name"].replace("jpg", "png")
+    label_name = os.path.join(opt.input_label_dir, filename)
+    inst_name = os.path.join(opt.output_instance_dir, filename)
+    img = io.imread(label_name, as_gray=True)
+
+    annIds = coco.getAnnIds(imgIds=id, catIds=[], iscrowd=None)
+    anns = coco.loadAnns(annIds)
+    count = 0
+    for ann in anns:
+        # check membership before accessing ann["segmentation"] in the type test
+        if "segmentation" in ann and type(ann["segmentation"]) == list:
+            # rasterize each polygon of this annotation with a distinct instance id
+            for seg in ann["segmentation"]:
+                poly = np.array(seg).reshape((int(len(seg) / 2), 2))
+                rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1)
+                img[rr, cc] = count
+            count += 1
+
+    io.imsave(inst_name, img)
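+
+# Example usage (paths are illustrative and assume the archives from the URLs
+# above were downloaded and unzipped under database/coco):
+#   python datasets/coco_generate_instance_map.py \
+#       --annotation_file database/coco/annotations/instances_train2017.json \
+#       --input_label_dir database/coco/train_label \
+#       --output_instance_dir database/coco/train_inst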
diff --git a/distillers/base_resnet_distiller.py b/distillers/base_resnet_distiller.py
index d9ad3ec..444de81 100644
--- a/distillers/base_resnet_distiller.py
+++ b/distillers/base_resnet_distiller.py
@@ -71,6 +71,11 @@ def __init__(self, opt):
                                           opt.student_netG, opt.norm, opt.student_dropout_rate,
                                           opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
+        if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
+            self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
+                                                      opt.student_netG.replace('super_', ''), opt.norm,
+                                                      opt.student_dropout_rate, opt.init_type, opt.init_gain,
+                                                      self.gpu_ids, opt=opt)
         if hasattr(opt, 'distiller'):
             self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
                                                      opt.pretrained_netG, opt.norm, 0,
@@ -221,6 +226,8 @@ def load_networks(self, verbose=True):
             util.load_network(self.netG_teacher, self.opt.restore_teacher_G_path, verbose)
         if self.opt.restore_student_G_path is not None:
             util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
+        if hasattr(self, 'netG_student_tmp'):
+            util.load_network(self.netG_student_tmp, self.opt.restore_student_G_path, verbose)
         if self.opt.restore_D_path is not None:
             util.load_network(self.netD, self.opt.restore_D_path, verbose)
         if self.opt.restore_A_path is not None:
diff --git a/docs/simplified_pipeline.md b/docs/simplified_pipeline.md
new file mode 100644
index 0000000..150885a
--- /dev/null
+++ b/docs/simplified_pipeline.md
@@ -0,0 +1,190 @@
+# Simplified Pipeline
+## Prerequisites
+
+* Linux
+* Python 3
+* CPU or NVIDIA GPU + CUDA CuDNN
+
+## Getting Started
+
+### Preparations
+
+Please refer to our [README](../README.md) for the installation, dataset preparation, and the evaluation (FID and mIoU).
+
+### Pipeline
+
+Below we show a simplified pipeline for compressing pix2pix and cycleGAN models. **We provide pre-trained models after each step.
You could use the pretrained models to skip some steps.** For more training details, please refer to our code.
+
+## Pix2pix Model Compression
+
+We will show the whole pipeline on the `edges2shoes-r` dataset. You could change the dataset name to the other datasets (`map2sat` and `cityscapes`).
+
+##### Train a MobileNet Teacher Model (The same as the full pipeline)
+
+Train a MobileNet-style teacher model from scratch.
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
+```
+We provide a pre-trained teacher for each dataset. You could download the pre-trained model by
+```shell
+python scripts/download_model.py --model pix2pix --task edges2shoes-r_lite --stage mobile
+```
+
+and test the model by
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/test_mobile.sh
+```
+
+##### "Once-for-all" Network Training
+
+Train a "once-for-all" network from a pre-trained student model to search for efficient architectures.
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
+```
+
+We provide a trained once-for-all network for each dataset. You could download the model by
+
+```shell
+python scripts/download_model.py --model pix2pix --task edges2shoes-r_lite --stage supernet
+```
+
+##### Select the Best Model
+
+Evaluate all the candidate sub-networks given a specific configuration:
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/search.sh
+```
+
+The search result will be stored as a Python `pickle` file. The pickle file contains a Python `list` with the information of all the candidate sub-networks; each element is a Python `dict` of the form
+
+```
+{'config_str': $config_str, 'macs': $macs, 'fid'/'mIoU': $fid_or_mIoU}
+```
+
+such as
+
+```python
+{'config_str': '32_32_48_40_64_40_16_32', 'macs': 5761662976, 'fid': 30.594936138634836}
+```
+
+`'config_str'` is a channel configuration description to identify a specific subnet within the "once-for-all" network.
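+
+Since the result is an ordinary pickle file, a quick way to inspect it and shortlist sub-networks is a few lines of Python (a minimal sketch — the path assumes the search command above, and `select_arch.py` below does this more thoroughly):
+
+```python
+import pickle
+
+with open('logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl', 'rb') as f:
+    results = pickle.load(f)  # list of dicts as described above
+
+# Keep sub-networks within a 6.5G MACs budget and rank by FID (lower is better).
+candidates = [r for r in results if r['macs'] <= 6.5e9]
+best = min(candidates, key=lambda r: r['fid'])
+print(best['config_str'], best['macs'], best['fid'])
+```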
+
+To accelerate the search process, you may want to search the sub-networks on multiple GPUs. You could manually split the search space with [search.py](../search.py). All you need to do is add the additional arguments `--split` and `--remainder`. For example, if you need to search the sub-networks with 2 GPUs, you could use the following commands:
+
+* On the first GPU:
+
+  ```bash
+  python search.py --dataroot database/edges2shoes-r \
+    --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+    --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls/result0.pkl \
+    --ngf 64 --batch_size 32 \
+    --config_set channels-64-pix2pix \
+    --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9 \
+    --split 2 --remainder 0
+  ```
+
+* On the second GPU:
+
+  ```bash
+  python search.py --dataroot database/edges2shoes-r \
+    --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+    --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls/result1.pkl \
+    --ngf 64 --batch_size 32 \
+    --config_set channels-64-pix2pix \
+    --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9 \
+    --split 2 --remainder 1 --gpu_ids 1
+  ```
+
+Then you could merge the search results with [merge.py](../merge.py):
+
+```bash
+python merge.py --in_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls \
+  --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl
+```
+
+Once you get the search results, you could use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want.
+
+```shell
+python select_arch.py --macs 6.5e9 --fid 32 \
+  --pkl_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl
+```
+
+##### Fine-tuning the Best Model
+
+(Optional) To further improve the performance of your chosen subnet, you may fine-tune it within the pre-trained "once-for-all" network. For example, if you want to fine-tune a subnet with `'config_str': 32_32_48_40_64_40_16_32`, use the following command:
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/finetune.sh 32_32_48_40_64_40_16_32
+```
+
+##### Export the Model
+
+Extract a subnet from the "once-for-all" network. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_40_64_40_16_32`, then you can export the model by this command:
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/export.sh 32_32_48_40_64_40_16_32
+```
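+
+You could also evaluate the exported (or downloaded) compressed generator and measure its latency with the provided scripts. As written, they load the model from the `pretrained/` directory; adjust `--restore_G_path` in the scripts if your compressed model lives elsewhere (e.g., under `logs/`):
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh
+bash scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh
+```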
+
+## CycleGAN Model Compression
+
+The pipeline is almost identical to pix2pix. We will show the pipeline on the `horse2zebra` dataset.
+
+##### Train a MobileNet Teacher Model
+
+Train a MobileNet-style teacher model from scratch.
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
+```
+
+We provide a pre-trained teacher model for each dataset. You could download the model using
+
+```shell
+python scripts/download_model.py --model cycle_gan --task horse2zebra_lite --stage mobile
+```
+
+and test the model by
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/test_mobile.sh
+```
+
+##### "Once-for-all" Network Training
+
+Train a "once-for-all" network from a pre-trained student model to search for efficient architectures.
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
+```
+
+We provide a pre-trained once-for-all network for each dataset. You could download the model by
+
+```shell
+python scripts/download_model.py --model cycle_gan --task horse2zebra_lite --stage supernet
+```
+
+##### Select the Best Model
+
+Evaluate all the candidate sub-networks given a specific configuration:
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/search.sh
+```
+To support multi-GPU search, you could manually split the search space with the additional arguments `--split` and `--remainder` and merge the results with [merge.py](../merge.py), the same as for pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as for pix2pix.
+
+##### Fine-tuning the Best Model
+
+During our experiments, we observe that fine-tuning the model on horse2zebra increases FID. **You may skip the fine-tuning.**
+
+##### Export the Model
+
+Extract a subnet from the supernet. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `24_16_32_16_32_64_16_24`, then you can export the model by this command:
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/export.sh 24_16_32_16_32_64_16_24
+```
diff --git a/docs/training_tutorial.md b/docs/training_tutorial.md
index 9417253..f76916b 100644
--- a/docs/training_tutorial.md
+++ b/docs/training_tutorial.md
@@ -13,9 +13,9 @@ Please refer to our [README](../README.md) for the installation, dataset prepara
 
 ### Pipeline
 
-Below we show the full pipeline for compressing pix2pix and cycleGAN models. **We provide pre-trained models after each step. You could use the pretrained models to skip some steps.** For more training details, please refer to [Appendix 6.1 Complete Pipeline](https://arxiv.org/pdf/2003.08936.pdf) of our paper.
+Below we show the full pipeline for compressing pix2pix, cycleGAN, and GauGAN models. **We provide pre-trained models after each step. You could use the pretrained models to skip some steps.** For more training details, please refer to [Appendix 6.1 Complete Pipeline](https://arxiv.org/pdf/2003.08936.pdf) of our paper.
 
-In fact, several steps including "Train a MobileNet Teacher Model", "Pre-distillation", and "Fine-tuning the Best Model" may be omitted from the whole pipeline. We will provide a simplified pipeline soon.
+In fact, several steps including "Train a MobileNet Teacher Model", "Pre-distillation", and "Fine-tuning the Best Model" may be omitted from the whole pipeline. Please check [simplified_pipeline.md](./simplified_pipeline.md) for our simplified pipeline.
 
@@ -94,7 +94,40 @@ such as
 
 `'config_str'` is a channel configuration description to identify a specific subnet within the "once-for-all" network.
 
-You could use our auxiliary script `select_arch.py` to select the architecture you want.
+To accelerate the search process, you may want to search the sub-networks on multiple GPUs. You could manually split the search space with [search.py](../search.py). All you need to do is add the additional arguments `--split` and `--remainder`.
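+
+Under the hood, these two flags simply partition the enumeration of channel configurations round-robin, so each worker evaluates a disjoint subset of the search space (cf. `all_configs` in the `configs` package); schematically:
+
+```python
+def split_search_space(configs, split, remainder):
+    # Worker `remainder` (of `split` workers) takes every `split`-th configuration.
+    return [cfg for i, cfg in enumerate(configs) if i % split == remainder]
+
+configs = list(range(10))                 # stand-in for the enumerated channel configs
+print(split_search_space(configs, 2, 0))  # [0, 2, 4, 6, 8]
+print(split_search_space(configs, 2, 1))  # [1, 3, 5, 7, 9]
+```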
For example, if you need to search the sub-networks with 2 GPUs, you could use the following commands:
+
+* On the first GPU:
+
+  ```bash
+  python search.py --dataroot database/edges2shoes-r \
+    --restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
+    --output_path logs/pix2pix/edges2shoes-r/supernet/pkls/result0.pkl \
+    --ngf 48 --batch_size 32 \
+    --config_set channels-48 \
+    --real_stat_path real_stat/edges2shoes-r_B.npz \
+    --split 2 --remainder 0
+  ```
+
+* On the second GPU:
+
+  ```bash
+  python search.py --dataroot database/edges2shoes-r \
+    --restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
+    --output_path logs/pix2pix/edges2shoes-r/supernet/pkls/result1.pkl \
+    --ngf 48 --batch_size 32 \
+    --config_set channels-48 \
+    --real_stat_path real_stat/edges2shoes-r_B.npz \
+    --split 2 --remainder 1 --gpu_ids 1
+  ```
+
+Then you could merge the search results with [merge.py](../merge.py):
+
+```bash
+python merge.py --in_dir logs/pix2pix/edges2shoes-r/supernet/pkls \
+  --output_path logs/pix2pix/edges2shoes-r/supernet/result.pkl
+```
+
+You could use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want.
 
 ```shell
 python select_arch.py --macs 5.7e9 --fid 30 \  # macs <= 5.7e9(10x), fid >= 30
@@ -111,7 +144,7 @@ bash scripts/pix2pix/edges2shoes-r/finetune.sh 32_32_48_32_48_48_16_16
 ```
 
 ##### Export the Model
 
-Extract a subnet from the "once-for-all" network. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_32_48_48_16_16`, then you can export the model by this command:
+Extract a subnet from the "once-for-all" network. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_32_48_48_16_16`, then you can export the model by this command:
 
 ```shell
 bash scripts/pix2pix/edges2shoes-r/export.sh 32_32_48_32_48_48_16_16
@@ -186,7 +219,9 @@ Evaluate all the candidate sub-networks given a specific configuration
 
 ```shell
 bash scripts/cycle_gan/horse2zebra/search.sh
 ```
-You could also use our auxiliary script `select_arch.py` to select the architecture you want. The usage is the same as pix2pix.
+To support multi-GPU search, you could manually split the search space with the additional arguments `--split` and `--remainder` and merge the results with [merge.py](../merge.py), the same as for pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as for pix2pix.
 
 ##### Fine-tuning the Best Model
 
@@ -200,7 +235,7 @@ During our experiments, we observe that fine-tuning the model on horse2zebra inc
 
 ##### Export the Model
 
-Extract a subnet from the supernet. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
+Extract a subnet from the supernet. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description.
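+A `config_str` is simply an underscore-separated list of per-layer channel numbers; schematically, decoding it amounts to the following (a sketch — see `decode_config` in the `configs` package for the actual implementation):
+
+```python
+def decode_config_str(config_str):
+    return {'channels': [int(c) for c in config_str.split('_')]}
+
+print(decode_config_str('16_16_32_16_32_32_16_16'))
+# {'channels': [16, 16, 32, 16, 32, 32, 16, 16]}
+```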
For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
 
 ```shell
 bash scripts/cycle_gan/horse2zebra/export.sh 16_16_32_16_32_32_16_16
 ```
@@ -270,21 +305,23 @@ Evaluate all the candidate sub-networks given a specific configuration (e.g., MA
 
 ```shell
 bash scripts/gaugan/cityscapes/search.sh
 ```
-You could also use our auxiliary script `select_arch.py` to select the architecture you want. The usage is the same as pix2pix.
+To support multi-GPU search, you could manually split the search space with the additional arguments `--split` and `--remainder` and merge the results with [merge.py](../merge.py), the same as for pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as for pix2pix.
 
 ##### Fine-tuning the Best Model
 
-(Optional) Fine-tune a specific subnet within the pre-trained "once-for-all" network. For example, if you want to fine-tune a subnet within the "once-for-all" network with `'config_str': 32_32_48_32_48_48_16_16`, try the following command:
+(Optional) Fine-tune a specific subnet within the pre-trained "once-for-all" network. For example, if you want to fine-tune a subnet within the "once-for-all" network with `'config_str': 32_40_40_32_40_24_32_24`, try the following command:
 
 ```shell
-bash scripts/gaugan/cityscapes/finetune.sh 16_16_32_16_32_32_16_16
+bash scripts/gaugan/cityscapes/finetune.sh 32_40_40_32_40_24_32_24
 ```
 
 ##### Export the Model
 
-Extract a subnet from the supernet. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
+Extract a subnet from the supernet. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_40_40_32_40_24_32_24`, then you can export the model by this command:
 
 ```shell
-bash scripts/gaugan/cityscapes/export.sh 16_16_32_16_32_32_16_16
+bash scripts/gaugan/cityscapes/export.sh 32_40_40_32_40_24_32_24
 ```
diff --git a/scripts/cycle_gan/horse2zebra_lite/export.sh b/scripts/cycle_gan/horse2zebra_lite/export.sh
new file mode 100644
index 0000000..b90c746
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/export.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+python export.py \
+  --input_path logs/cycle_gan/horse2zebra_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+  --output_path logs/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+  --ngf 64 --config_str $1
diff --git a/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh b/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh
new file mode 100644
index 0000000..902af05
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python latency.py --dataroot database/horse2zebra/valA \
+  --dataset_mode single \
+  --results_dir results-pretrained/cycle_gan/horse2zebra_lite/compressed \
+  --config_str 24_16_32_16_32_64_16_24 \
+  --restore_G_path pretrained/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+  --need_profile \
+  --real_stat_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/search.sh b/scripts/cycle_gan/horse2zebra_lite/search.sh
new file mode 100644
index 0000000..caf461b
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/search.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python search.py --dataroot database/horse2zebra/valA \
+  --dataset_mode single \
+  --restore_G_path logs/cycle_gan/horse2zebra_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+  --output_path 
logs/cycle_gan/horse2zebra_lite/supernet-stage2/result.pkl \
+  --ngf 64 --batch_size 32 \
+  --config_set channels-64-cycleGAN \
+  --real_stat_path real_stat/horse2zebra_B.npz --budget 3.6e9
diff --git a/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh b/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh
new file mode 100644
index 0000000..58ccb04
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python test.py --dataroot database/horse2zebra/valA \
+  --dataset_mode single \
+  --results_dir results-pretrained/cycle_gan/horse2zebra_lite/compressed \
+  --config_str 24_16_32_16_32_64_16_24 \
+  --restore_G_path pretrained/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+  --need_profile \
+  --real_stat_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh b/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
new file mode 100644
index 0000000..3769974
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+python train.py --dataroot database/horse2zebra \
+  --model cycle_gan \
+  --log_dir logs/cycle_gan/horse2zebra_lite/mobile \
+  --real_stat_A_path real_stat/horse2zebra_A.npz \
+  --real_stat_B_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh b/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
new file mode 100644
index 0000000..6e3bfbc
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+python train_supernet.py --dataroot database/horse2zebra \
+  --dataset_mode unaligned \
+  --supernet resnet \
+  --log_dir logs/cycle_gan/horse2zebra_lite/supernet-stage1 \
+  --gan_mode lsgan \
+  --student_ngf 64 --ndf 64 \
+  --restore_teacher_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/horse2zebra_B.npz \
+  --lambda_recon 10 --lambda_distill 0.01 \
+  --nepochs 50 --nepochs_decay 50 \
+  --save_epoch_freq 20 \
+  --config_set channels-64-cycleGAN-stage1 --sort_channels
+python train_supernet.py --dataroot database/horse2zebra \
+  --dataset_mode unaligned \
+  --supernet resnet \
+  --log_dir logs/cycle_gan/horse2zebra_lite/supernet-stage2 \
+  --gan_mode lsgan \
+  --student_ngf 64 --ndf 64 \
+  --restore_teacher_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/cycle_gan/horse2zebra_lite/supernet-stage1/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/cycle_gan/horse2zebra_lite/supernet-stage1/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/horse2zebra_B.npz \
+  --lambda_recon 10 --lambda_distill 0.01 \
+  --nepochs 200 --nepochs_decay 200 \
+  --save_epoch_freq 20 \
+  --config_set channels-64-cycleGAN
diff --git a/scripts/download_model.py b/scripts/download_model.py
index cc9ca22..5e8d4a9 100644
--- a/scripts/download_model.py
+++ b/scripts/download_model.py
@@ -6,12 +6,11 @@
 
 def check(opt):
     if opt.model == 'pix2pix':
-        assert opt.task in ['edges2shoes-r', 'map2sat', 'cityscapes']
+        assert opt.task in ['edges2shoes-r', 'map2sat', 'cityscapes', 'edges2shoes-r_lite']
     elif opt.model == 'cycle_gan':
-        assert opt.task in ['horse2zebra']
+        assert opt.task in ['horse2zebra', 'horse2zebra_lite']
     elif opt.model == 
'gaugan': assert opt.task in ['cityscapes'] - assert opt.stage in ['compressed', 'full'] else: raise NotImplementedError('Unsupported model [%s]!' % opt.model) diff --git a/scripts/pix2pix/edges2shoes-r/export.sh b/scripts/pix2pix/edges2shoes-r/export.sh index aca2550..e67d7ee 100644 --- a/scripts/pix2pix/edges2shoes-r/export.sh +++ b/scripts/pix2pix/edges2shoes-r/export.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash python export.py \ --input_path logs/pix2pix/edges2shoes-r/finetune/checkpoints/latest_net_G.pth \ - --output_path logs/pix2pix/edges2shoes-r/compressed/compressed_net_G.pth \ + --output_path logs/pix2pix/edges2shoes-r/compressed/latest_net_G.pth \ --config_str $1 diff --git a/scripts/pix2pix/edges2shoes-r_lite/export.sh b/scripts/pix2pix/edges2shoes-r_lite/export.sh new file mode 100644 index 0000000..f840677 --- /dev/null +++ b/scripts/pix2pix/edges2shoes-r_lite/export.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +python export.py \ + --input_path logs/pix2pix/edges2shoes-r_lite/finetune/checkpoints/latest_net_G.pth \ + --output_path logs/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \ + --ngf 64 --config_str $1 diff --git a/scripts/pix2pix/edges2shoes-r_lite/finetune.sh b/scripts/pix2pix/edges2shoes-r_lite/finetune.sh new file mode 100644 index 0000000..e8addb2 --- /dev/null +++ b/scripts/pix2pix/edges2shoes-r_lite/finetune.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +python train_supernet.py --dataroot database/edges2shoes-r \ + --supernet resnet \ + --log_dir logs/pix2pix/edges2shoes-r_lite/finetune \ + --batch_size 4 \ + --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \ + --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \ + --restore_D_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_D.pth \ + --real_stat_path real_stat/edges2shoes-r_B.npz \ + --nepochs 5 --nepochs_decay 15 \ + --teacher_ngf 64 --student_ngf 64 \ + --config_str $1 diff --git a/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh b/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh new file mode 100644 index 0000000..43e36f1 --- /dev/null +++ b/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +python latency.py --dataroot database/edges2shoes-r \ + --results_dir results-pretrained/pix2pix/edges2shoes-r_lite/compressed \ + --restore_G_path pretrained/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \ + --config_str 32_32_48_40_64_40_16_32 \ + --real_stat_path real_stat/edges2shoes-r_B.npz \ + --need_profile --num_test 500 diff --git a/scripts/pix2pix/edges2shoes-r_lite/search.sh b/scripts/pix2pix/edges2shoes-r_lite/search.sh new file mode 100644 index 0000000..17b2e00 --- /dev/null +++ b/scripts/pix2pix/edges2shoes-r_lite/search.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +python search.py --dataroot database/edges2shoes-r \ + --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \ + --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl \ + --ngf 64 --batch_size 32 \ + --config_set channels-64-pix2pix \ + --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9 diff --git a/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh b/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh new file mode 100644 index 0000000..879d13d --- /dev/null +++ b/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +python test.py 
--dataroot database/edges2shoes-r \
+  --results_dir results-pretrained/pix2pix/edges2shoes-r_lite/compressed \
+  --restore_G_path pretrained/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \
+  --config_str 32_32_48_40_64_40_16_32 \
+  --real_stat_path real_stat/edges2shoes-r_B.npz \
+  --need_profile --num_test 500
diff --git a/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh b/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
new file mode 100644
index 0000000..86268e5
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+python train.py --dataroot database/edges2shoes-r \
+  --model pix2pix \
+  --log_dir logs/pix2pix/edges2shoes-r_lite/mobile \
+  --real_stat_path real_stat/edges2shoes-r_B.npz
diff --git a/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh b/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
new file mode 100644
index 0000000..b7c1335
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+python train_supernet.py --dataroot database/edges2shoes-r \
+  --supernet resnet \
+  --log_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage1 \
+  --batch_size 4 \
+  --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/edges2shoes-r_B.npz \
+  --nepochs 2 --nepochs_decay 8 \
+  --teacher_ngf 64 --student_ngf 64 \
+  --config_set channels-64-pix2pix-stage1 --sort_channels
+
+python train_supernet.py --dataroot database/edges2shoes-r \
+  --supernet resnet \
+  --log_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage2 \
+  --batch_size 4 \
+  --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage1/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/pix2pix/edges2shoes-r_lite/supernet-stage1/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/edges2shoes-r_B.npz \
+  --nepochs 10 --nepochs_decay 30 \
+  --teacher_ngf 64 --student_ngf 64 \
+  --config_set channels-64-pix2pix
diff --git a/supernets/resnet_supernet.py b/supernets/resnet_supernet.py
index 5d1ab49..0defae9 100644
--- a/supernets/resnet_supernet.py
+++ b/supernets/resnet_supernet.py
@@ -2,8 +2,8 @@
 import os
 
 import torch
-import torch.nn.functional as F
 from torch import nn
+from torch.nn import functional as F
 from tqdm import tqdm
 
 from configs import decode_config
@@ -13,6 +13,7 @@
 from metric import get_fid, get_mIoU
 from models.modules.super_modules import SuperConv2d
 from utils import util
+from utils.weight_transfer import load_pretrained_weight
 
 
 class ResnetSupernet(BaseResnetDistiller):
@@ -20,6 +21,8 @@ def modify_commandline_options(parser, is_train):
         assert is_train
         parser = super(ResnetSupernet, ResnetSupernet).modify_commandline_options(parser, is_train)
+        parser.add_argument('--sort_channels', action='store_true',
+                            help='whether to sort the channels of student G by L1 norm')
         parser.set_defaults(norm='instance', student_netG='super_mobile_resnet_9blocks',
                             dataset_mode='aligned', log_dir='logs/supernet')
         return parser
@@ -170,3 +173,11 @@ def evaluate_model(self, step):
     def test(self, config):
         with torch.no_grad():
             self.forward(config)
+
+    def load_networks(self, verbose=True):
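+        # When --sort_channels is set, the base class has loaded the student
+        # checkpoint into the plain (non-super) netG_student_tmp; transfer its
+        # weights into the supernet afterwards, sorting channels by L1 norm.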
+        super(ResnetSupernet, self).load_networks(verbose)
+        if hasattr(self, 'netG_student_tmp'):
+            load_pretrained_weight(self.opt.student_netG.replace('super_', ''), self.opt.student_netG,
+                                   self.netG_student_tmp, self.netG_student,
+                                   self.opt.student_ngf, self.opt.student_ngf)
+            del self.netG_student_tmp
diff --git a/utils/weight_transfer.py b/utils/weight_transfer.py
index 38d54c0..63f2d08 100644
--- a/utils/weight_transfer.py
+++ b/utils/weight_transfer.py
@@ -3,12 +3,14 @@
 from models.modules.mobile_modules import SeparableConv2d
 from models.modules.resnet_architecture.mobile_resnet_generator import MobileResnetBlock
 from models.modules.resnet_architecture.resnet_generator import ResnetBlock
+from models.modules.resnet_architecture.super_mobile_resnet_generator import SuperMobileResnetBlock
 from models.modules.spade_architecture.mobile_spade_generator import MobileSPADEGenerator, MobileSPADEResnetBlock, \
     MobileSPADE
+from models.modules.super_modules import SuperConv2d, SuperConvTranspose2d, SuperSeparableConv2d
 
 
 def transfer_Conv2d(m1, m2, input_index=None, output_index=None):
-    assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
+    assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))
     if m1.out_channels == 3:  # If this is the last convolution
         assert input_index is not None
         m2.weight.data = m1.weight.data[:, input_index].clone()
@@ -42,7 +44,7 @@ def transfer_Conv2d(m1, m2, input_index=None, output_index=None):
 
 
 def transfer_ConvTranspose2d(m1, m2, input_index=None, output_index=None):
-    assert isinstance(m1, nn.ConvTranspose2d) and isinstance(m2, nn.ConvTranspose2d)
+    assert isinstance(m1, nn.ConvTranspose2d) and isinstance(m2, (nn.ConvTranspose2d, SuperConvTranspose2d))
     assert output_index is None
     p = m1.weight.data
     if input_index is None:
@@ -60,7 +62,7 @@ def transfer_ConvTranspose2d(m1, m2, input_index=None, output_index=None):
 
 
 def transfer_SeparableConv2d(m1, m2, input_index=None, output_index=None):
-    assert isinstance(m1, SeparableConv2d) and isinstance(m2, SeparableConv2d)
+    assert isinstance(m1, SeparableConv2d) and isinstance(m2, (SeparableConv2d, SuperSeparableConv2d))
     dw1, pw1 = m1.conv[0], m1.conv[2]
     dw2, pw2 = m2.conv[0], m2.conv[2]
 
@@ -77,7 +79,7 @@ def transfer_SeparableConv2d(m1, m2, input_index=None, output_index=None):
 
 
 def transfer_MobileResnetBlock(m1, m2, input_index=None, output_index=None):
-    assert isinstance(m1, MobileResnetBlock) and isinstance(m2, MobileResnetBlock)
+    assert isinstance(m1, MobileResnetBlock) and isinstance(m2, (MobileResnetBlock, SuperMobileResnetBlock))
     assert output_index is None
     idxs = transfer(m1.conv_block[1], m2.conv_block[1], input_index=input_index)
     idxs = transfer(m1.conv_block[6], m2.conv_block[6], input_index=idxs, output_index=input_index)
@@ -125,7 +127,6 @@ def transfer_MobileSPADEResnetBlock(m1, m2, input_index=None, output_index=None)
 
 
 def transfer(m1, m2, input_index=None, output_index=None):
-    assert type(m1) == type(m2)
     if isinstance(m1, nn.Conv2d):
         return transfer_Conv2d(m1, m2, input_index, output_index)
     elif isinstance(m1, nn.ConvTranspose2d):
@@ -145,7 +146,6 @@ def transfer(m1, m2, input_index=None, output_index=None):
 
 
 def load_pretrained_weight(model1, model2, netA, netB, ngf1, ngf2):
-    assert model1 == model2
     assert ngf1 >= ngf2
 
     if isinstance(netA, nn.DataParallel):
@@ -162,7 +162,7 @@ def load_pretrained_weight(model1, model2, netA, netB, ngf1, ngf2):
         assert len(net1.model) == len(net2.model)
         for i in range(28):
             m1, m2 = net1.model[i], net2.model[i]
-            assert type(m1) == type(m2)
+            # assert 
type(m1) == type(m2) if isinstance(m1, (nn.Conv2d, nn.ConvTranspose2d, MobileResnetBlock)): index = transfer(m1, m2, index) elif model1 == 'resnet_9blocks':