Skip to content

Commit

Permalink
release the simplified pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
lmxyy committed Jul 21, 2020
1 parent fca6281 commit 8bbb652
Show file tree
Hide file tree
Showing 24 changed files with 476 additions and 98 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# GAN Compression
### [project](https://hanlab.mit.edu/projects/gancompression/) | [paper](https://arxiv.org/abs/2003.08936) | [demo](https://www.youtube.com/playlist?list=PL80kAHvQbh-r5R8UmXhQK1ndqRvPNw_ex)

**[NEW!]** The simplified pipeline of GAN Compression is released! Check the [tutorial](docs/simplified_pipeline.md) for the pipeline.

**[NEW!]** GauGAN training code and tutorial is released! Check the [tutorial](docs/training_tutorial.md) to compress GauGAN.

**[NEW!]** Correct metric naming and update the evaluation protocol. Support MACs budget for searching.
Expand All @@ -9,8 +11,6 @@

**[NEW!]** The [tutorial](docs/training_tutorial.md) of compression is released! Check the [overview](docs/overview.md) for better understanding our codebase.

**[NEW!]** The PyTorch implementation of a general conditional GAN Compression framework is released.

![teaser](imgs/teaser.png)*We introduce GAN Compression, a general-purpose method for compressing conditional GANs. Our method reduces the computation of widely-used conditional GAN models, including pix2pix, CycleGAN, and GauGAN, by 9-21x while preserving the visual fidelity. Our method is effective for a wide range of generator architectures, learning objectives, and both paired and unpaired settings.*

GAN Compression: Efficient Architectures for Interactive Conditional GANs<br>
Expand Down Expand Up @@ -189,7 +189,7 @@ PyTorch Colab notebook: [CycleGAN](https://colab.research.google.com/github/mit-

### <span id="cityscapes">Cityscapes Dataset</span>

For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script `datasets/prepare_cityscapes_dataset.py` to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:
For the Cityscapes dataset, we cannot provide it due to license issue. Please download the dataset from https://cityscapes-dataset.com and use the script [prepare_cityscapes_dataset.py](datasets/prepare_cityscapes_dataset.py) to preprocess it. You need to download `gtFine_trainvaltest.zip` and `leftImg8bit_trainvaltest.zip` and unzip them in the same folder. For example, you may put `gtFine` and `leftImg8bit` in `database/cityscapes-origin`. You need to prepare the dataset with the following commands:

```shell
python datasets/get_trainIds.py database/cityscapes-origin/gtFine/
Expand All @@ -210,7 +210,7 @@ Please refer to our training [tutorial](docs/training_tutorial.md) on how to tra

### FID Computation

To compute the FID score, you need to get some statistical information from the groud-truth images of your dataset. We provide a script `get_real_stat.py` to extract statistical information. For example, for the edges2shoes dataset, you could run the following command:
To compute the FID score, you need to get some statistical information from the groud-truth images of your dataset. We provide a script [get_real_stat.py](./get_real_stat.py) to extract statistical information. For example, for the edges2shoes dataset, you could run the following command:

```shell
python get_real_stat.py \
Expand Down
16 changes: 16 additions & 0 deletions configs/resnet_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,22 @@ def get_configs(config_name):
return ResnetConfigs(n_channels=[[32, 24, 16], [32, 24, 16], [32, 24, 16],
[32, 24, 16], [32, 24, 16], [32, 24, 16],
[32, 24, 16], [32, 24, 16]])
elif config_name == 'channels-64-cycleGAN-stage1':
return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 32],
[64, 48, 32], [64, 48, 32], [64, 48, 32],
[64, 48, 32], [64, 48, 32]])
elif config_name == 'channels-64-cycleGAN':
return ResnetConfigs(n_channels=[[64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16],
[64, 48, 32, 24, 16], [64, 48, 32, 24, 16], [64, 48, 32, 24, 16],
[64, 48, 32, 24, 16], [64, 48, 32, 24, 16]])
elif config_name == 'channels-64-pix2pix-stage1':
return ResnetConfigs(n_channels=[[64, 48], [64, 48], [64, 48],
[64, 48], [64, 48], [64, 48],
[64, 48, 32], [64, 48, 32]])
elif config_name == 'channels-64-pix2pix':
return ResnetConfigs(n_channels=[[64, 48, 32], [64, 48, 32], [64, 48, 40, 32],
[64, 48, 40, 32], [64, 48, 40, 32], [64, 48, 40, 32],
[64, 48, 32, 24, 16], [64, 48, 32, 24, 16]])
elif config_name == 'debug':
return ResnetConfigs(n_channels=[[48, 32], [48, 32], [48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32],
Expand Down
75 changes: 3 additions & 72 deletions configs/spade_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def smallest(self):
ret['channels'].append(min(n_channel))
return ret

def all_configs(self, split=1, remainder=0):
def all_configs(self):

def yield_channels(i):
if i == len(self.n_channels):
Expand All @@ -39,9 +39,8 @@ def yield_channels(i):
for after_channels in yield_channels(i + 1):
yield [n] + after_channels

for i, channels in enumerate(yield_channels(0)):
if i % split == remainder:
yield {'channels': channels}
for channels in yield_channels(0):
yield {'channels': channels}

def __call__(self, name):
assert name in ('largest', 'smallest')
Expand All @@ -64,74 +63,6 @@ def get_configs(config_name):
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels-48-part1':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 24], [48, 24], [48, 24]])
elif config_name == 'channels-48-part2':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 24], [48, 24], [40, 32]])
elif config_name == 'channels-48-part3':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 24], [40, 32], [48, 24]])
elif config_name == 'channels-48-part4':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[48, 24], [40, 32], [40, 32]])
elif config_name == 'channels-48-part5':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[40, 32], [48, 24], [48, 24]])
elif config_name == 'channels-48-part6':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[40, 32], [48, 24], [40, 32]])
elif config_name == 'channels-48-part7':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[40, 32], [40, 32], [48, 24]])
elif config_name == 'channels-48-part8':
return SPADEConfigs(n_channels=[[48, 40, 32],
[48, 40, 32], [48, 40, 32], [48, 40, 32], [48, 40, 32],
[40, 32], [40, 32], [40, 32]])
elif config_name == 'channels2-48':
return SPADEConfigs(n_channels=[[48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part1':
return SPADEConfigs(n_channels=[[48],
[48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part2':
return SPADEConfigs(n_channels=[[40],
[48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part3':
return SPADEConfigs(n_channels=[[32],
[48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part4':
return SPADEConfigs(n_channels=[[24],
[48, 24], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part5':
return SPADEConfigs(n_channels=[[48],
[40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part6':
return SPADEConfigs(n_channels=[[40],
[40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part7':
return SPADEConfigs(n_channels=[[32],
[40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'channels2-48-part8':
return SPADEConfigs(n_channels=[[24],
[40, 32], [48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24],
[48, 40, 32, 24], [48, 40, 32, 24], [48, 40, 32, 24]])
elif config_name == 'debug':
return SPADEConfigs(n_channels=[[48],
[48], [48], [48], [48],
Expand Down
56 changes: 56 additions & 0 deletions datasets/coco_generate_instance_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import argparse
import os

import numpy as np
import skimage.io as io
import tqdm
from pycocotools.coco import COCO
from skimage.draw import polygon

parser = argparse.ArgumentParser()
parser.add_argument('--annotation_file', type=str, default="./annotations/instances_train2017.json",
help="Path to the annocation file. It can be downloaded at http://images.cocodataset.org/annotations/annotations_trainval2017.zip. Should be either instances_train2017.json or instances_val2017.json")
parser.add_argument('--input_label_dir', type=str, default="./train_label/",
help="Path to the directory containing label maps. It can be downloaded at http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip")
parser.add_argument('--output_instance_dir', type=str, default="./train_inst/",
help="Path to the output directory of instance maps")
opt = parser.parse_args()
os.makedirs(opt.output_instance_dir, exist_ok=True)

print("annotation file at {}".format(opt.annotation_file))
print("input label maps at {}".format(opt.input_label_dir))
print("output dir at {}".format(opt.output_instance_dir))

# initialize COCO api for instance annotations
coco = COCO(opt.annotation_file)

# display COCO categories and supercategories
cats = coco.loadCats(coco.getCatIds())
imgIds = coco.getImgIds(catIds=coco.getCatIds(cats))
for ix, id in enumerate(tqdm.tqdm(imgIds)):
# if ix % 50 == 0:
# print("{} / {}".format(ix, len(imgIds)))
img_dict = coco.loadImgs(id)[0]
filename = img_dict["file_name"].replace("jpg", "png")
label_name = os.path.join(opt.input_label_dir, filename)
inst_name = os.path.join(opt.output_instance_dir, filename)
img = io.imread(label_name, as_gray=True)

annIds = coco.getAnnIds(imgIds=id, catIds=[], iscrowd=None)
anns = coco.loadAnns(annIds)
count = 0
for ann in anns:
if type(ann["segmentation"]) == list:
if "segmentation" in ann:
for seg in ann["segmentation"]:
poly = np.array(seg).reshape((int(len(seg) / 2), 2))
rr, cc = polygon(poly[:, 1] - 1, poly[:, 0] - 1)
img[rr, cc] = count
count += 1

io.imsave(inst_name, img)
7 changes: 7 additions & 0 deletions distillers/base_resnet_distiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def __init__(self, opt):
opt.student_netG, opt.norm, opt.student_dropout_rate,
opt.init_type, opt.init_gain, self.gpu_ids, opt=opt)

if getattr(opt, 'sort_channels', False) and opt.restore_student_G_path is not None:
self.netG_student_tmp = networks.define_G(opt.input_nc, opt.output_nc, opt.student_ngf,
opt.student_netG.replace('super_', ''), opt.norm,
opt.student_dropout_rate, opt.init_type, opt.init_gain,
self.gpu_ids, opt=opt)
if hasattr(opt, 'distiller'):
self.netG_pretrained = networks.define_G(opt.input_nc, opt.output_nc, opt.pretrained_ngf,
opt.pretrained_netG, opt.norm, 0,
Expand Down Expand Up @@ -221,6 +226,8 @@ def load_networks(self, verbose=True):
util.load_network(self.netG_teacher, self.opt.restore_teacher_G_path, verbose)
if self.opt.restore_student_G_path is not None:
util.load_network(self.netG_student, self.opt.restore_student_G_path, verbose)
if hasattr(self, 'netG_student_tmp'):
util.load_network(self.netG_student_tmp, self.opt.restore_student_G_path, verbose)
if self.opt.restore_D_path is not None:
util.load_network(self.netD, self.opt.restore_D_path, verbose)
if self.opt.restore_A_path is not None:
Expand Down
Loading

0 comments on commit 8bbb652

Please sign in to comment.