-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
116 lines (100 loc) · 4.19 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import collections
import glob
import shutil
import sys
from datetime import datetime
from pathlib import Path
import PIL.Image as Image
import torch
from torchvision.transforms import transforms
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.model as module_arch
from parse_config import ConfigParser
def main(config):
logger = config.get_logger('infernece')
# setup data_loader instances
data_loader = getattr(module_data, config['data_loader']['type'])(
config['data_loader']['args']['data_dir'],
patch_size=config['data_loader']['args']['patch_size'],
batch_size=512,
shuffle=False,
validation_split=0.0,
training=False,
num_workers=2
)
# build model architecture
model = config.init_obj('arch', module_arch)
logger.info(model)
# prepare model for testing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info('Loading checkpoint: {} ...'.format(config.resume))
checkpoint = torch.load(config.resume, map_location=device)
state_dict = checkpoint['state_dict']
if config['n_gpu'] > 1:
model = torch.nn.DataParallel(model)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()
with torch.no_grad():
for i, (data, target) in enumerate(tqdm(data_loader)):
data, target = data.to(device), target.to(device)
output = model(data)
batch_size = data_loader.batch_size
patch_idx = torch.arange(
batch_size * i, batch_size * i + data.shape[0])
pred = torch.argmax(output, dim=1)
data_loader.dataset.patches.store_data(
patch_idx, [pred.unsqueeze(1)])
preds = [(data_loader.dataset.patches.combine(idx, data_idx=0).cpu(),
data_loader.dataset.data[idx])
for idx in range(len(data_loader.dataset.data))]
trsfm = transforms.ToPILImage()
out_dir = list(config.save_dir.parts)
out_dir[-3] = 'output'
out_dir = Path(*out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for pred, path in preds:
filename = Path(path).stem + '.png'
pred = trsfm(pred.float())
pred.save(out_dir / filename)
if __name__ == '__main__':
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_argument('-c', '--config', default=None, type=str,
help='config file path (default: None)')
args.add_argument('-r', '--resume', default=None, type=str,
help='path to latest checkpoint (default: None)')
args.add_argument('-d', '--device', default=None, type=str,
help='indices of GPUs to enable (default: all)')
args.add_argument('--data', default=None, type=str,
help='path to data (default: None)')
run_id = datetime.now().strftime(r'%m%d_%H%M%S')
dst_data = Path('.data/', run_id)
data_dir = dst_data / 'test' / 'images'
masks_dir = dst_data / 'test' / 'masks'
data_dir.mkdir(parents=True, exist_ok=True)
masks_dir.mkdir(parents=True, exist_ok=True)
if '--data' in sys.argv:
src_data = Path(sys.argv[sys.argv.index('--data') + 1])
if src_data.is_file():
shutil.copy(src_data, data_dir)
mask = Image.new('1', Image.open(src_data).size)
mask.save(masks_dir / (src_data.stem + '.png'), 'PNG')
else:
for filename in glob.glob('*.jpg', root_dir=src_data):
file_path = src_data / filename
if file_path.is_file():
shutil.copy(file_path, data_dir)
mask = Image.new('1', Image.open(file_path).size)
mask.save(masks_dir / (file_path.stem + '.png'), 'PNG')
sys.argv += ['--data_dir', str(dst_data)]
# custom cli options to modify configuration from default values given in json file.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
CustomArgs(['--data_dir'], type=str,
target='data_loader;args;data_dir')
]
config = ConfigParser.from_args(args, options)
main(config)
shutil.rmtree(dst_data)