diff --git a/README.md b/README.md
index 1852509..02ddb4e 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,8 @@
# GAN Compression
### [project](https://hanlab.mit.edu/projects/gancompression/) | [paper](https://arxiv.org/abs/2003.08936) | [demo](https://www.youtube.com/playlist?list=PL80kAHvQbh-r5R8UmXhQK1ndqRvPNw_ex)
+**[NEW!]** The simplified pipeline of GAN Compression is released! Check the [tutorial](docs/simplified_pipeline.md) for the pipeline.
+
**[NEW!]** GauGAN training code and tutorial is released! Check the [tutorial](docs/training_tutorial.md) to compress GauGAN.
**[NEW!]** Correct metric naming and update the evaluation protocol. Support MACs budget for searching.
@@ -9,8 +11,6 @@
**[NEW!]** The [tutorial](docs/training_tutorial.md) of compression is released! Check the [overview](docs/overview.md) for better understanding our codebase.
-**[NEW!]** The PyTorch implementation of a general conditional GAN Compression framework is released.
-
![teaser](imgs/teaser.png)*We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.*
GAN Compression: Efficient Architectures for Interactive Conditional GANs
@@ -189,7 +189,7 @@ PyTorch Colab notebook: [CycleGAN](https://colab.research.google.com/github/mit-
### Cityscapes Dataset
-For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script `datasets/prepare_cityscapes_dataset.py` to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:
+For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script [prepare_cityscapes_dataset.py](datasets/prepare_cityscapes_dataset.py) to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:
```shell
python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
@@ -210,7 +210,7 @@ Please refer to our training [tutorial](docs/training_tutorial.md) on how to tra
### FID Computation
-To compute the FID score, you need to get some statistical information from the groud-truth images of your dataset. We provide a script `get_real_stat.py` to extract statistical information. For example, for the edges2shoes dataset, you could run the following command:
+To compute the FID score, you need to get some statistical information from the ground-truth images of your dataset. We provide a script [get_real_stat.py](./get_real_stat.py) to extract statistical information. For example, for the edges2shoes dataset, you could run the following command:
```shell
python get_real_stat.py \
diff --git a/configs/resnet_configs.py b/configs/resnet_configs.py
index d5d463a..8c91e2b 100644
--- a/configs/resnet_configs.py
+++ b/configs/resnet_configs.py
@@ -71,6 +71,22 @@ def get_configs(config_name):
return ResnetConfigs(n_channels=[[32, 24, 16], [32, 24, 16], [32, 24, 16],
[32, 24, 16], [32, 24, 16], [32, 24, 16],
[32, 24, 16], [32, 24, 16]])
+ elif config_name == 'channels-64-cycleGAN-stage1':
+ return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 32],
+ [64, 48, 32], [64, 48, 32], [64, 48, 32],
+ [64, 48, 32], [64, 48, 32]])
+ elif config_name == 'channels-64-cycleGAN':
+ return ResnetConfigs(n_channels=[[64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16],
+ [64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16],
+ [64, 48, 32, 24, 16], [64, 48, 32, 24, 16]])
+ elif config_name == 'channels-64-pix2pix-stage1':
+ return ResnetConfigs(n_channels=[[64, 48], [64, 48], [64, 48],
+ [64, 48], [64, 48], [64, 48],
+ [64, 48, 32], [64, 48, 32]])
+ elif config_name == 'channels-64-pix2pix':
+ return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 40, 32],
+ [64, 48, 40, 32], [64, 48, 40, 32], [64, 48, 40, 32],
+ [64, 48, 32, 24, 16], [64, 48, 32, 24, 16]])
elif config_name == 'debug':
return ResnetConfigs(n_channels=[[48, 32], [48, 32], [48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32],
diff --git a/configs/spade_configs.py b/configs/spade_configs.py
index 537fed0..988ff30 100644
--- a/configs/spade_configs.py
+++ b/configs/spade_configs.py
@@ -29,7 +29,7 @@ def smallest(self):
ret['channels'].append(min(n_channel))
return ret
- def all_configs(self, split=1, remainder=0):
+ def all_configs(self):
def yield_channels(i):
if i == len(self.n_channels):
@@ -39,9 +39,8 @@ def yield_channels(i):
for after_channels in yield_channels(i + 1):
yield [n] + after_channels
- for i, channels in enumerate(yield_channels(0)):
- if i % split == remainder:
- yield {'channels': channels}
+ for channels in yield_channels(0):
+ yield {'channels': channels}
def __call__(self, name):
assert name in ('largest', 'smallest')
@@ -64,74 +63,6 @@ def get_configs(config_name):
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels-48-part1':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [48, 24], [48, 24], [48, 24]])
- elif config_name == 'channels-48-part2':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [48, 24], [48, 24], [40, 32]])
- elif config_name == 'channels-48-part3':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [48, 24], [40, 32], [48, 24]])
- elif config_name == 'channels-48-part4':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [48, 24], [40, 32], [40, 32]])
- elif config_name == 'channels-48-part5':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [40, 32], [48, 24], [48, 24]])
- elif config_name == 'channels-48-part6':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [40, 32], [48, 24], [40, 32]])
- elif config_name == 'channels-48-part7':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [40, 32], [40, 32], [48, 24]])
- elif config_name == 'channels-48-part8':
- return SPADEConfigs(n_channels=[[48, 40, 32],
- [48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
- [40, 32], [40, 32], [40, 32]])
- elif config_name == 'channels2-48':
- return SPADEConfigs(n_channels=[[48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part1':
- return SPADEConfigs(n_channels=[[48],
- [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part2':
- return SPADEConfigs(n_channels=[[40],
- [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part3':
- return SPADEConfigs(n_channels=[[32],
- [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part4':
- return SPADEConfigs(n_channels=[[24],
- [48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part5':
- return SPADEConfigs(n_channels=[[48],
- [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part6':
- return SPADEConfigs(n_channels=[[40],
- [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part7':
- return SPADEConfigs(n_channels=[[32],
- [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
- elif config_name == 'channels2-48-part8':
- return SPADEConfigs(n_channels=[[24],
- [40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
- [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'debug':
return SPADEConfigs(n_channels=[[48],
[48], [48], [48], [48],
diff --git a/datasets/coco_generate_instance_map.py b/datasets/coco_generate_instance_map.py
new file mode 100644
index 0000000..606d9ab
--- /dev/null
+++ b/datasets/coco_generate_instance_map.py
@@ -0,0 +1,56 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import argparse
+import os
+
+import numpy as np
+import skimage.io as io
+import tqdm
+from pycocotools.coco import COCO
+from skimage.draw import polygon
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--annotation_file', type=str, default="./annotations/instances_train2017.json",
+ help="Path to the annocation file. It can be downloaded at http://images.cocodataset.org/annotations/annotations_trainval2017.zip. Should be either instances_train2017.json or instances_val2017.json")
+parser.add_argument('--input_label_dir', type=str, default="./train_label/",
+ help="Path to the directory containing label maps. It can be downloaded at http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip")
+parser.add_argument('--output_instance_dir', type=str, default="./train_inst/",
+ help="Path to the output directory of instance maps")
+opt = parser.parse_args()
+os.makedirs(opt.output_instance_dir, exist_ok=True)
+
+print("annotation file at {}".format(opt.annotation_file))
+print("input label maps at {}".format(opt.input_label_dir))
+print("output dir at {}".format(opt.output_instance_dir))
+
+# initialize COCO api for instance annotations
+coco = COCO(opt.annotation_file)
+
+# display COCO categories and supercategories
+cats = coco.loadCats(coco.getCatIds())
+imgIds = coco.getImgIds(catIds=coco.getCatIds(cats))
+for ix, id in enumerate(tqdm.tqdm(imgIds)):
+ # if ix % 50 == 0:
+ # print("{} / {}".format(ix, len(imgIds)))
+ img_dict = coco.loadImgs(id)[0]
+ filename = img_dict["file_name"].replace("jpg", "png")
+ label_name = os.path.join(opt.input_label_dir, filename)
+ inst_name = os.path.join(opt.output_instance_dir, filename)
+ img = io.imread(label_name, as_gray=True)
+
+ annIds = coco.getAnnIds(imgIds=id, catIds=[], iscrowd=None)
+ anns = coco.loadAnns(annIds)
+ count = 0
+ for ann in anns:
+ if type(ann["segmentation"]) == list:
+ if "segmentation" in ann:
+ for seg in ann["segmentation"]:
+ poly = np.array(seg).reshape((int(len(seg) / 2), 2))
+ rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1)
+ img[rr, cc] = count
+ count += 1
+
+ io.imsave(inst_name, img)
diff --git a/distillers/base_resnet_distiller.py b/distillers/base_resnet_distiller.py
index d9ad3ec..444de81 100644
--- a/distillers/base_resnet_distiller.py
+++ b/distillers/base_resnet_distiller.py
@@ -71,6 +71,11 @@ def __init__(self, opt):
opt.student_netG, opt.norm, opt.student_dropout_rate,
opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)
+ if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
+ self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
+ opt.student_netG.replace('super_', ''), opt.norm,
+ opt.student_dropout_rate, opt.init_type, opt.init_gain,
+ self.gpu_ids, opt=opt)
if hasattr(opt, 'distiller'):
self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
opt.pretrained_netG, opt.norm, 0,
@@ -221,6 +226,8 @@ def load_networks(self, verbose=True):
util.load_network(self.netG_teacher, self.opt.restore_teacher_G_path, verbose)
if self.opt.restore_student_G_path is not None:
util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
+ if hasattr(self, 'netG_student_tmp'):
+ util.load_network(self.netG_student_tmp, self.opt.restore_student_G_path, verbose)
if self.opt.restore_D_path is not None:
util.load_network(self.netD, self.opt.restore_D_path, verbose)
if self.opt.restore_A_path is not None:
diff --git a/docs/simplified_pipeline.md b/docs/simplified_pipeline.md
new file mode 100644
index 0000000..150885a
--- /dev/null
+++ b/docs/simplified_pipeline.md
@@ -0,0 +1,190 @@
+# Simplified Pipeline Tutorial
+## Prerequisites
+
+* Linux
+* Python 3
+* CPU or NVIDIA GPU + CUDA CuDNN
+
+## Getting Started
+
+### Preparations
+
+Please refer to our [README](../README.md) for the installation, dataset preparations, and the evaluation (FID and mIoU).
+
+### Pipeline
+
+Below we show a simplified pipeline for compressing pix2pix and cycleGAN models. **We provide pre-trained models after each step. You could use the pretrained models to skip some steps.** For more training details, please refer to our code.
+
+## Pix2pix Model Compression
+
+We will show the whole pipeline on `edges2shoes-r` dataset. You could change the dataset name to other datasets (`map2sat` and `cityscapes`).
+
+##### Train a MobileNet Teacher Model (The same as the full pipeline)
+
+Train a MobileNet-style teacher model from scratch.
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
+```
+We provide a pre-trained teacher for each dataset. You could download the pre-trained model by
+```shell
+python scripts/download_model.py --model pix2pix --task edges2shoes-r_lite --stage mobile
+```
+
+and test the model by
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/test_mobile.sh
+```
+
+##### "Once-for-all" Network Training
+
+Train a "once-for-all" network from a pre-trained student model to search for the efficient architectures.
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
+```
+
+We provide a trained once-for-all network for each dataset. You could download the model by
+
+```shell
+python scripts/download_model.py --model pix2pix --task edges2shoes-r_lite --stage supernet
+```
+
+##### Select the Best Model
+
+Evaluate all the candidate sub-networks given a specific configuration
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/search.sh
+```
+
+The search results are stored as a Python `pickle` file. The file contains a Python `list` with one entry per candidate sub-network; each entry is a Python `dict` of the form
+
+```
+{'config_str': $config_str, 'macs': $macs, 'fid'/'mIoU': $fid_or_mIoU}
+```
+
+such as
+
+```python
+{'config_str': '32_32_48_40_64_40_16_32', 'macs': 5761662976, 'fid': 30.594936138634836}
+```
+
+`'config_str'` is a channel configuration description to identify a specific subnet within the "once-for-all" network.
+
+To accelerate the search process, you may want to search the sub-networks on multiple GPUs. You can manually split the search space with [search.py](../search.py): all you need to do is add the arguments `--split` and `--remainder`. For example, to search the sub-networks with 2 GPUs, use the following commands:
+
+* On the first GPU:
+
+ ```bash
+ python search.py --dataroot database/edges2shoes-r \
+ --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls/result0.pkl \
+ --ngf 64 --batch_size 32 \
+ --config_set channels-64-pix2pix \
+ --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9 \
+ --split 2 --remainder 0
+ ```
+
+* On the second GPU:
+
+ ```bash
+ python search.py --dataroot database/edges2shoes-r \
+ --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls/result1.pkl \
+ --ngf 64 --batch_size 32 \
+ --config_set channels-64-pix2pix \
+ --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9 \
+ --split 2 --remainder 1 --gpu_ids 1
+ ```
+
+Then you could merge the search results with [merge.py](../merge.py)
+
+```bash
+python merge.py --in_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage2/pkls \
+ --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl
+```
+
+Once you get the search results, you could use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want.
+
+```shell
+python select_arch.py --macs 6.5e9 --fid 32 \
+ --pkl_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl
+```
+
+##### Fine-tuning the Best Model
+
+(Optional) Fine-tune a specific subnet within the pre-trained "once-for-all" network. To further improve the performance of your chosen subnet, you may need to fine-tune the subnet. For example, if you want to fine-tune a subnet within the "once-for-all" network with `'config_str': 32_32_48_40_64_40_16_32`, use the following command:
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/finetune.sh 32_32_48_40_64_40_16_32
+```
+
+##### Export the Model
+
+Extract a subnet from the "once-for-all" network. We provide a script [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_40_64_40_16_32`, then you can export the model by this command:
+
+```shell
+bash scripts/pix2pix/edges2shoes-r_lite/export.sh 32_32_48_40_64_40_16_32
+```
+
+## CycleGAN Model Compression
+
+The pipeline is almost identical to pix2pix. We will show the pipeline on `horse2zebra` dataset.
+
+##### Train a MobileNet Teacher Model
+
+Train a MobileNet-style teacher model from scratch.
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
+```
+
+We provide a pre-trained teacher model for each dataset. You could download the model using
+
+```shell
+python scripts/download_model.py --model cycle_gan --task horse2zebra_lite --stage mobile
+```
+
+and test the model by
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/test_mobile.sh
+```
+
+##### "Once-for-all" Network Training
+
+Train a "once-for-all" network from a pre-trained student model to search for the efficient architectures.
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
+```
+
+We provide a pre-trained once-for-all network for each dataset. You could download the model by
+
+```shell
+python scripts/download_model.py --model cycle_gan --task horse2zebra_lite --stage supernet
+```
+
+##### Select the Best Model
+
+Evaluate all the candidate sub-networks given a specific configuration
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/search.sh
+```
+To support multi-GPU search, you could manually split the search space with additional arguments `--split` and `--remainder` and merge them with [merge.py](../merge.py), which is the same as pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as pix2pix.
+
+##### Fine-tuning the Best Model
+
+During our experiments, we observe that fine-tuning the model on horse2zebra increases FID. **You may skip the fine-tuning.**
+
+##### Export the Model
+
+Extract a subnet from the supernet. We provide a code [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `24_16_32_16_32_64_16_24`, then you can export the model by this command:
+
+```shell
+bash scripts/cycle_gan/horse2zebra_lite/export.sh 24_16_32_16_32_64_16_24
+```
diff --git a/docs/training_tutorial.md b/docs/training_tutorial.md
index 9417253..f76916b 100644
--- a/docs/training_tutorial.md
+++ b/docs/training_tutorial.md
@@ -13,9 +13,9 @@ Please refer to our [README](../README.md) for the installation, dataset prepara
### Pipeline
-Below we show the full pipeline for compressing pix2pix and cycleGAN models. **We provide pre-trained models after each step. You could use the pretrained models to skip some steps.** For more training details, please refer to [Appendix 6.1 Complete Pipeline](https://arxiv.org/pdf/2003.08936.pdf) of our paper.
+Below we show the full pipeline for compressing pix2pix, cycleGAN and GauGAN models. **We provide pre-trained models after each step. You could use the pretrained models to skip some steps.** For more training details, please refer to [Appendix 6.1 Complete Pipeline](https://arxiv.org/pdf/2003.08936.pdf) of our paper.
-In fact, several steps including "Train a MobileNet Teacher Model", "Pre-distillation", and "Fine-tuning the Best Model" may be omitted from the whole pipeline. We will provide a simplified pipeline soon.
+In fact, several steps including "Train a MobileNet Teacher Model", "Pre-distillation", and "Fine-tuning the Best Model" may be omitted from the whole pipeline. Please check [simplified_pipeline.md](./simplified_pipeline.md) for our simplified pipeline.
## Pix2pix Model Compression
@@ -94,7 +94,40 @@ such as
`'config_str'` is a channel configuration description to identify a specific subnet within the "once-for-all" network.
-You could use our auxiliary script `select_arch.py` to select the architecture you want.
+To accelerate the search process, you may want to search the sub-networks on multiple GPUs. You can manually split the search space with [search.py](../search.py): all you need to do is add the arguments `--split` and `--remainder`. For example, to search the sub-networks with 2 GPUs, use the following commands:
+
+* On the first GPU:
+
+ ```bash
+ python search.py --dataroot database/edges2shoes-r \
+ --restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r/supernet/pkls/result0.pkl \
+ --ngf 48 --batch_size 32 \
+ --config_set channels-48 \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --split 2 --remainder 0
+ ```
+
+* On the second GPU:
+
+ ```bash
+ python search.py --dataroot database/edges2shoes-r \
+ --restore_G_path logs/pix2pix/edges2shoes-r/supernet/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r/supernet/pkls/result1.pkl \
+ --ngf 48 --batch_size 32 \
+ --config_set channels-48 \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --split 2 --remainder 1 --gpu_ids 1
+ ```
+
+Then you could merge the search results with [merge.py](../merge.py)
+
+```bash
+python merge.py --in_dir logs/pix2pix/edges2shoes-r/supernet/pkls \
+ --output_path logs/pix2pix/edges2shoes-r/supernet/result.pkl
+```
+
+You could use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want.
```shell
python select_arch.py --macs 5.7e9 --fid 30 \ # macs <= 5.7e9(10x), fid >= 30
@@ -111,7 +144,7 @@ bash scripts/pix2pix/edges2shoes-r/finetune.sh 32_32_48_32_48_48_16_16
##### Export the Model
-Extract a subnet from the "once-for-all" network. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_32_48_48_16_16`, then you can export the model by this command:
+Extract a subnet from the "once-for-all" network. We provide a code [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_32_48_32_48_48_16_16`, then you can export the model by this command:
```shell
bash scripts/pix2pix/edges2shoes-r/export.sh 32_32_48_32_48_48_16_16
@@ -186,7 +219,9 @@ Evaluate all the candidate sub-networks given a specific configuration
```shell
bash scripts/cycle_gan/horse2zebra/search.sh
```
-You could also use our auxiliary script `select_arch.py` to select the architecture you want. The usage is the same as pix2pix.
+To support multi-GPU search, you could manually split the search space with additional arguments `--split` and `--remainder` and merge them with [merge.py](../merge.py), which is the same as pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as pix2pix.
##### Fine-tuning the Best Model
@@ -200,7 +235,7 @@ During our experiments, we observe that fine-tuning the model on horse2zebra inc
##### Export the Model
-Extract a subnet from the supernet. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
+Extract a subnet from the supernet. We provide a code [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
```shell
bash scripts/cycle_gan/horse2zebra/export.sh 16_16_32_16_32_32_16_16
@@ -270,21 +305,23 @@ Evaluate all the candidate sub-networks given a specific configuration (e.g., MA
```shell
bash scripts/gaugan/cityscapes/search.sh
```
-You could also use our auxiliary script `select_arch.py` to select the architecture you want. The usage is the same as pix2pix.
+To support multi-GPU search, you could manually split the search space with additional arguments `--split` and `--remainder` and merge them with [merge.py](../merge.py), which is the same as pix2pix.
+
+You could also use our auxiliary script [select_arch.py](../select_arch.py) to select the architecture you want. The usage is the same as pix2pix.
##### Fine-tuning the Best Model
(Optional) Fine-tune a specific subnet within the pre-trained "once-for-all" network. For example, if you want to fine-tune a subnet within the "once-for-all" network with `'config_str': 32_32_48_32_48_48_16_16`, try the following command:
```shell
-bash scripts/gaugan/cityscapes/finetune.sh 16_16_32_16_32_32_16_16
+bash scripts/gaugan/cityscapes/finetune.sh 32_40_40_32_40_24_32_24
```
##### Export the Model
-Extract a subnet from the supernet. We provide a code `export.py` to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `16_16_32_16_32_32_16_16`, then you can export the model by this command:
+Extract a subnet from the supernet. We provide a code [export.py](../export.py) to extract a specific subnet according to a configuration description. For example, if the `config_str` of your chosen subnet is `32_40_40_32_40_24_32_24`, then you can export the model by this command:
```shell
-bash scripts/gaugan/cityscapes/export.sh 16_16_32_16_32_32_16_16
+bash scripts/gaugan/cityscapes/export.sh 32_40_40_32_40_24_32_24
```
diff --git a/scripts/cycle_gan/horse2zebra_lite/export.sh b/scripts/cycle_gan/horse2zebra_lite/export.sh
new file mode 100644
index 0000000..b90c746
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/export.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+python export.py \
+ --input_path logs/cycle_gan/horse2zebra_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --output_path logs/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+ --ngf 64 --config_str $1
diff --git a/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh b/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh
new file mode 100644
index 0000000..902af05
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/latency_compressed.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python latency.py --dataroot database/horse2zebra/valA \
+ --dataset_mode single \
+ --results_dir results-pretrained/cycle_gan/horse2zebra_lite/compressed \
+ --config_str 24_16_32_16_32_64_16_24 \
+ --restore_G_path pretrained/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+ --need_profile \
+ --real_stat_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/search.sh b/scripts/cycle_gan/horse2zebra_lite/search.sh
new file mode 100644
index 0000000..caf461b
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/search.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python search.py --dataroot database/horse2zebra/valA \
+ --dataset_mode single \
+ --restore_G_path logs/cycle_gan/horse2zebra_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --output_path logs/cycle_gan/horse2zebra_lite/supernet-stage2/result.pkl \
+ --ngf 64 --batch_size 32 \
+ --config_set channels-64-cycleGAN \
+ --real_stat_path real_stat/horse2zebra_B.npz --budget 3.6e9
diff --git a/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh b/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh
new file mode 100644
index 0000000..58ccb04
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/test_compressed.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+python test.py --dataroot database/horse2zebra/valA \
+ --dataset_mode single \
+ --results_dir results-pretrained/cycle_gan/horse2zebra_lite/compressed \
+ --config_str 24_16_32_16_32_64_16_24 \
+ --restore_G_path pretrained/cycle_gan/horse2zebra_lite/compressed/latest_net_G.pth \
+ --need_profile \
+ --real_stat_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh b/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
new file mode 100644
index 0000000..3769974
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/train_mobile.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+python train.py --dataroot database/horse2zebra \
+ --model cycle_gan \
+ --log_dir logs/cycle_gan/horse2zebra_lite/mobile \
+ --real_stat_A_path real_stat/horse2zebra_A.npz \
+ --real_stat_B_path real_stat/horse2zebra_B.npz
diff --git a/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh b/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
new file mode 100644
index 0000000..6e3bfbc
--- /dev/null
+++ b/scripts/cycle_gan/horse2zebra_lite/train_supernet.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+python train_supernet.py --dataroot database/horse2zebra \
+  --dataset_mode unaligned \
+  --supernet resnet \
+  --log_dir logs/cycle_gan/horse2zebra_lite/supernet-stage1 \
+  --gan_mode lsgan \
+  --student_ngf 64 --ndf 64 \
+  --restore_teacher_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/horse2zebra_B.npz \
+  --lambda_recon 10 --lambda_distill 0.01 \
+  --nepochs 50 --nepochs_decay 50 \
+  --save_epoch_freq 20 \
+  --config_set channels-64-cycleGAN-stage1 --sort_channels
+python train_supernet.py --dataroot database/horse2zebra \
+  --dataset_mode unaligned \
+  --supernet resnet \
+  --log_dir logs/cycle_gan/horse2zebra_lite/supernet-stage2 \
+  --gan_mode lsgan \
+  --student_ngf 64 --ndf 64 \
+  --restore_teacher_G_path logs/cycle_gan/horse2zebra_lite/mobile/checkpoints/latest_net_G.pth \
+  --restore_student_G_path logs/cycle_gan/horse2zebra_lite/supernet-stage1/checkpoints/latest_net_G.pth \
+  --restore_D_path logs/cycle_gan/horse2zebra_lite/supernet-stage1/checkpoints/latest_net_D.pth \
+  --real_stat_path real_stat/horse2zebra_B.npz \
+  --lambda_recon 10 --lambda_distill 0.01 \
+  --nepochs 200 --nepochs_decay 200 \
+  --save_epoch_freq 20 \
+  --config_set channels-64-cycleGAN
diff --git a/scripts/download_model.py b/scripts/download_model.py
index cc9ca22..5e8d4a9 100644
--- a/scripts/download_model.py
+++ b/scripts/download_model.py
@@ -6,12 +6,11 @@
def check(opt):
if opt.model == 'pix2pix':
- assert opt.task in ['edges2shoes-r', 'map2sat', 'cityscapes']
+ assert opt.task in ['edges2shoes-r', 'map2sat', 'cityscapes', 'edges2shoes-r_lite']
elif opt.model == 'cycle_gan':
- assert opt.task in ['horse2zebra']
+ assert opt.task in ['horse2zebra', 'horse2zebra_lite']
elif opt.model == 'gaugan':
assert opt.task in ['cityscapes']
- assert opt.stage in ['compressed', 'full']
else:
raise NotImplementedError('Unsupported model [%s]!' % opt.model)
diff --git a/scripts/pix2pix/edges2shoes-r/export.sh b/scripts/pix2pix/edges2shoes-r/export.sh
index aca2550..e67d7ee 100644
--- a/scripts/pix2pix/edges2shoes-r/export.sh
+++ b/scripts/pix2pix/edges2shoes-r/export.sh
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
python export.py \
--input_path logs/pix2pix/edges2shoes-r/finetune/checkpoints/latest_net_G.pth \
- --output_path logs/pix2pix/edges2shoes-r/compressed/compressed_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r/compressed/latest_net_G.pth \
--config_str $1
diff --git a/scripts/pix2pix/edges2shoes-r_lite/export.sh b/scripts/pix2pix/edges2shoes-r_lite/export.sh
new file mode 100644
index 0000000..f840677
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/export.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+# Run export.py on the fine-tuned generator checkpoint and write the result to the
+# compressed directory. Usage: bash scripts/pix2pix/edges2shoes-r_lite/export.sh <config_str>
+python export.py \
+ --input_path logs/pix2pix/edges2shoes-r_lite/finetune/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \
+ --ngf 64 --config_str "${1:?usage: export.sh <config_str>}"
diff --git a/scripts/pix2pix/edges2shoes-r_lite/finetune.sh b/scripts/pix2pix/edges2shoes-r_lite/finetune.sh
new file mode 100644
index 0000000..e8addb2
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/finetune.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+# Fine-tune with the channel configuration passed as the first argument, restoring student G
+# and D from supernet-stage2 and the teacher from the mobile run. Usage: finetune.sh <config_str>
+python train_supernet.py --dataroot database/edges2shoes-r \
+ --supernet resnet \
+ --log_dir logs/pix2pix/edges2shoes-r_lite/finetune \
+ --batch_size 4 \
+ --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+ --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --restore_D_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_D.pth \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --nepochs 5 --nepochs_decay 15 \
+ --teacher_ngf 64 --student_ngf 64 \
+ --config_str "${1:?usage: finetune.sh <config_str>}"
diff --git a/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh b/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh
new file mode 100644
index 0000000..43e36f1
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/latency_compressed.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+python latency.py --dataroot database/edges2shoes-r \
+ --results_dir results-pretrained/pix2pix/edges2shoes-r_lite/compressed \
+ --restore_G_path pretrained/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \
+ --config_str 32_32_48_40_64_40_16_32 \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --need_profile --num_test 500
diff --git a/scripts/pix2pix/edges2shoes-r_lite/search.sh b/scripts/pix2pix/edges2shoes-r_lite/search.sh
new file mode 100644
index 0000000..17b2e00
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/search.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+python search.py --dataroot database/edges2shoes-r \
+ --restore_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/checkpoints/latest_net_G.pth \
+ --output_path logs/pix2pix/edges2shoes-r_lite/supernet-stage2/result.pkl \
+ --ngf 64 --batch_size 32 \
+ --config_set channels-64-pix2pix \
+ --real_stat_path real_stat/edges2shoes-r_B.npz --load_in_memory --budget 6.5e9
diff --git a/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh b/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh
new file mode 100644
index 0000000..879d13d
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/test_compressed.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+python test.py --dataroot database/edges2shoes-r \
+ --results_dir results-pretrained/pix2pix/edges2shoes-r_lite/compressed \
+ --restore_G_path pretrained/pix2pix/edges2shoes-r_lite/compressed/latest_net_G.pth \
+ --config_str 32_32_48_40_64_40_16_32 \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --need_profile --num_test 500
diff --git a/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh b/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
new file mode 100644
index 0000000..86268e5
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/train_mobile.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+python train.py --dataroot database/edges2shoes-r \
+ --model pix2pix \
+ --log_dir logs/pix2pix/edges2shoes-r_lite/mobile \
+ --real_stat_path real_stat/edges2shoes-r_B.npz
diff --git a/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh b/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
new file mode 100644
index 0000000..b7c1335
--- /dev/null
+++ b/scripts/pix2pix/edges2shoes-r_lite/train_supernet.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Stage 1: train the supernet on the channels-64-pix2pix-stage1 config set with
+# --sort_channels, restoring teacher G, student G and D from the mobile run.
+python train_supernet.py --dataroot database/edges2shoes-r \
+ --supernet resnet \
+ --log_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage1 \
+ --batch_size 4 \
+ --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+ --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+ --restore_D_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_D.pth \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --nepochs 2 --nepochs_decay 8 \
+ --teacher_ngf 64 --student_ngf 64 \
+ --config_set channels-64-pix2pix-stage1 --sort_channels
+
+# Stage 2: continue on the channels-64-pix2pix config set, resuming student G and D
+# from the stage-1 checkpoints; the teacher is still the mobile generator.
+python train_supernet.py --dataroot database/edges2shoes-r \
+ --supernet resnet \
+ --log_dir logs/pix2pix/edges2shoes-r_lite/supernet-stage2 \
+ --batch_size 4 \
+ --restore_teacher_G_path logs/pix2pix/edges2shoes-r_lite/mobile/checkpoints/latest_net_G.pth \
+ --restore_student_G_path logs/pix2pix/edges2shoes-r_lite/supernet-stage1/checkpoints/latest_net_G.pth \
+ --restore_D_path logs/pix2pix/edges2shoes-r_lite/supernet-stage1/checkpoints/latest_net_D.pth \
+ --real_stat_path real_stat/edges2shoes-r_B.npz \
+ --nepochs 10 --nepochs_decay 30 \
+ --teacher_ngf 64 --student_ngf 64 \
+ --config_set channels-64-pix2pix
diff --git a/supernets/resnet_supernet.py b/supernets/resnet_supernet.py
index 5d1ab49..0defae9 100644
--- a/supernets/resnet_supernet.py
+++ b/supernets/resnet_supernet.py
@@ -2,8 +2,8 @@
import os
import torch
-import torch.nn.functional as F
from torch import nn
+from torch.nn import functional as F
from tqdm import tqdm
from configs import decode_config
@@ -13,6 +13,7 @@
from metric import get_fid, get_mIoU
from models.modules.super_modules import SuperConv2d
from utils import util
+from utils.weight_transfer import load_pretrained_weight
class ResnetSupernet(BaseResnetDistiller):
@@ -20,6 +21,8 @@ class ResnetSupernet(BaseResnetDistiller):
def modify_commandline_options(parser, is_train):
assert is_train
parser = super(ResnetSupernet, ResnetSupernet).modify_commandline_options(parser, is_train)
+ parser.add_argument('--sort_channels', action='store_true',
+ help='whether to sort the channels of student G by L1 norm')
parser.set_defaults(norm='instance', student_netG='super_mobile_resnet_9blocks',
dataset_mode='aligned', log_dir='logs/supernet')
return parser
@@ -170,3 +173,11 @@ def evaluate_model(self, step):
def test(self, config):
with torch.no_grad():
self.forward(config)
+
+ def load_networks(self, verbose=True):  # restore checkpoints, then seed the supernet student
+     super(ResnetSupernet, self).load_networks(verbose)  # forward the caller's verbosity instead of dropping it
+     if hasattr(self, 'netG_student_tmp'):  # netG_student_tmp is created outside this method (not visible here)
+         load_pretrained_weight(self.opt.student_netG.replace('super_', ''), self.opt.student_netG,
+                                self.netG_student_tmp, self.netG_student,
+                                self.opt.student_ngf, self.opt.student_ngf)  # same ngf on both sides
+         del self.netG_student_tmp  # drop the temporary source network once its weights are transferred
diff --git a/utils/weight_transfer.py b/utils/weight_transfer.py
index 38d54c0..63f2d08 100644
--- a/utils/weight_transfer.py
+++ b/utils/weight_transfer.py
@@ -3,12 +3,14 @@
from models.modules.mobile_modules import SeparableConv2d
from models.modules.resnet_architecture.mobile_resnet_generator import MobileResnetBlock
from models.modules.resnet_architecture.resnet_generator import ResnetBlock
+from models.modules.resnet_architecture.super_mobile_resnet_generator import SuperMobileResnetBlock
from models.modules.spade_architecture.mobile_spade_generator import MobileSPADEGenerator, MobileSPADEResnetBlock, \
MobileSPADE
+from models.modules.super_modules import SuperConv2d, SuperConvTranspose2d, SuperSeparableConv2d
def transfer_Conv2d(m1, m2, input_index=None, output_index=None):
- assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
+ assert isinstance(m1, nn.Conv2d) and isinstance(m2, (nn.Conv2d, SuperConv2d))
if m1.out_channels == 3: # If this is the last convolution
assert input_index is not None
m2.weight.data = m1.weight.data[:, input_index].clone()
@@ -42,7 +44,7 @@ def transfer_Conv2d(m1, m2, input_index=None, output_index=None):
def transfer_ConvTranspose2d(m1, m2, input_index=None, output_index=None):
- assert isinstance(m1, nn.ConvTranspose2d) and isinstance(m2, nn.ConvTranspose2d)
+ assert isinstance(m1, nn.ConvTranspose2d) and isinstance(m2, (nn.ConvTranspose2d, SuperConvTranspose2d))
assert output_index is None
p = m1.weight.data
if input_index is None:
@@ -60,7 +62,7 @@ def transfer_ConvTranspose2d(m1, m2, input_index=None, output_index=None):
def transfer_SeparableConv2d(m1, m2, input_index=None, output_index=None):
- assert isinstance(m1, SeparableConv2d) and isinstance(m2, SeparableConv2d)
+ assert isinstance(m1, SeparableConv2d) and isinstance(m2, (SeparableConv2d, SuperSeparableConv2d))
dw1, pw1 = m1.conv[0], m1.conv[2]
dw2, pw2 = m2.conv[0], m2.conv[2]
@@ -77,7 +79,7 @@ def transfer_SeparableConv2d(m1, m2, input_index=None, output_index=None):
def transfer_MobileResnetBlock(m1, m2, input_index=None, output_index=None):
- assert isinstance(m1, MobileResnetBlock) and isinstance(m2, MobileResnetBlock)
+ assert isinstance(m1, MobileResnetBlock) and isinstance(m2, (MobileResnetBlock, SuperMobileResnetBlock))
assert output_index is None
idxs = transfer(m1.conv_block[1], m2.conv_block[1], input_index=input_index)
idxs = transfer(m1.conv_block[6], m2.conv_block[6], input_index=idxs, output_index=input_index)
@@ -125,7 +127,6 @@ def transfer_MobileSPADEResnetBlock(m1, m2, input_index=None, output_index=None)
def transfer(m1, m2, input_index=None, output_index=None):
- assert type(m1) == type(m2)
if isinstance(m1, nn.Conv2d):
return transfer_Conv2d(m1, m2, input_index, output_index)
elif isinstance(m1, nn.ConvTranspose2d):
@@ -145,7 +146,6 @@ def transfer(m1, m2, input_index=None, output_index=None):
def load_pretrained_weight(model1, model2, netA, netB, ngf1, ngf2):
- assert model1 == model2
assert ngf1 >= ngf2
if isinstance(netA, nn.DataParallel):
@@ -162,7 +162,7 @@ def load_pretrained_weight(model1, model2, netA, netB, ngf1, ngf2):
assert len(net1.model) == len(net2.model)
for i in range(28):
m1, m2 = net1.model[i], net2.model[i]
- assert type(m1) == type(m2)
+ # No exact type check here: m2 may be the Super* counterpart of m1 (e.g. SuperConv2d for nn.Conv2d).
if isinstance(m1, (nn.Conv2d, nn.ConvTranspose2d, MobileResnetBlock)):
index = transfer(m1, m2, index)
elif model1 == 'resnet_9blocks':