-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_eval.py
129 lines (97 loc) · 4.87 KB
/
run_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pickle
from collections import defaultdict
from itertools import product
from pathlib import Path
from typing import Optional
import numpy as np
import jax
from jax import jit
import jax.numpy as jnp
from jax.image import resize
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from args import parser
from data import ImageFolder
from model import build_thera
from utils import make_grid, compute_metrics
from vendor.matlab_bicubic import imresize as matlab_imresize
# Per-channel statistics (3 channels, presumably RGB) passed to
# jax.nn.standardize in prepare_batch and used in evaluate to map the
# decoder output back to pixel space.
MEAN = np.array([0.4488, 0.4371, 0.4040])
VAR = np.full(3, 0.25)
# Cap on the decoded (high-resolution) patch side: evaluate uses
# MAX_PATCH_SIZE // scale as the source patch size, so patch_size * scale
# (which drives memory use) stays roughly constant across scales.
MAX_PATCH_SIZE = 256
def prepare_batch(target, scale):
    """Derive the model inputs for one evaluation batch from its ground truth.

    Args:
        target: batch of ground-truth images. Axis 1 is moved to the end, so
            this presumably arrives channels-first (N, C, H, W), e.g. from
            torchvision's ToTensor plus a DataLoader — TODO confirm.
        scale: upsampling factor; the low-resolution size is the target size
            divided by ``scale`` and truncated to an int.

    Returns:
        ``(source, source_up, target_t, target)`` where ``source`` is the
        downsampled image normalised with ``MEAN``/``VAR``, ``source_up`` is
        the un-normalised ``source`` resized (nearest) to the target shape,
        ``target_t`` is a 1-element float32 array holding ``scale ** -2``, and
        ``target`` is the channels-last ground truth cropped to an exact
        multiple of the source size.

    Note:
        Only the first image of the batch is downsampled (``target[0]``), so
        this assumes batch size 1.
    """
    hr = jnp.asarray(target).transpose((0, 2, 3, 1))
    lr_h = int(hr.shape[1] / scale)
    lr_w = int(hr.shape[2] / scale)
    # Crop the ground truth so it is exactly `scale` times the source size.
    hr = hr[:, :lr_h * scale, :lr_w * scale]

    t = jnp.float32(scale**(-2))[None]

    lr = matlab_imresize(hr[0], output_shape=(lr_h, lr_w))[None]
    # Nearest-neighbour upsample of the *un-normalised* source.
    lr_up = resize(lr, hr.shape, 'nearest')
    lr = jax.nn.standardize(lr, mean=MEAN, variance=VAR)

    return lr, lr_up, t, hr
def _patch_spans(length, patch_size):
    """Yield ``(start, stop)`` pairs of width ``patch_size`` covering ``[0, length)``.

    Spans lie on a regular grid; the last one is shifted back so it ends exactly
    at ``length`` (it may overlap its predecessor, but no span is emitted twice).
    Requires ``0 < patch_size <= length``.
    """
    num_patches = -(-length // patch_size)  # ceil(length / patch_size)
    for i in range(num_patches):
        start = min(i * patch_size, length - patch_size)
        yield start, start + patch_size


def _crop_border(x, border):
    """Drop ``border`` pixels from each spatial edge of a channels-last batch.

    A non-positive ``border`` returns ``x`` unchanged; slicing with
    ``x[:, 0:-0]`` would otherwise produce an empty array.
    """
    if border <= 0:
        return x
    return x[:, border:-border, border:-border]


def evaluate(val_loader, model, params, scale, border_crop,
             do_ensemble, save_dir: Optional[Path] = None, y_only=False):
    """Run tiled super-resolution on every image of a loader and average the metrics.

    Args:
        val_loader: iterable of ground-truth batches, each accepted by
            ``prepare_batch`` (batch size 1 is assumed there).
        model: object exposing ``apply_encoder(params, source)`` and
            ``apply_decoder(params, encoding, coords, t)``; both are jitted here.
        params: model parameters forwarded to the encoder/decoder.
        scale: integer upsampling factor.
        border_crop: pixels removed from each spatial edge of both prediction
            and target before computing metrics (0 disables cropping).
        do_ensemble: if true, average predictions over the 4 ``rot90``
            rotations of the input; otherwise use a single pass.
        save_dir: if given, each prediction is written there as ``{i}.png``.
        y_only: forwarded to ``compute_metrics``.

    Returns:
        dict mapping each metric name reported by ``compute_metrics`` to its
        mean over all images (empty if the loader yields nothing).
    """
    apply_encoder = jit(model.apply_encoder)
    apply_decoder = jit(model.apply_decoder)
    metrics = defaultdict(list)

    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)

    for i_img, target in enumerate(tqdm(val_loader)):
        source, source_up, target_t, target = prepare_batch(target, scale)

        # Memory scales with patch_size * scale, so we keep that factor constant.
        # The patch is clamped to the smaller source side, which keeps it valid
        # after the 90-degree rotations below swap height and width. The lower
        # bound of 1 avoids a zero patch size for very large scales.
        patch_size = max(1, min(MAX_PATCH_SIZE // scale, *source.shape[1:3]))
        target_coords = jnp.tile(make_grid(patch_size * scale), (target.shape[0], 1, 1, 1))

        outs = []
        for i_rot in range(4 if do_ensemble else 1):
            source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
            source_up_ = jnp.rot90(source_up, k=i_rot, axes=(-3, -2))
            encoding = apply_encoder(params, source_)
            assert encoding.shape[:-1] == source_.shape[:-1]

            # Decode the encoding patch by patch. Every patch has the same
            # size, so the same target_coords serve all of them.
            out = np.full_like(source_up_, np.nan, dtype=np.float32)
            for (h_min, h_max), (w_min, w_max) in product(
                    _patch_spans(source_.shape[1], patch_size),
                    _patch_spans(source_.shape[2], patch_size)):
                encoding_p = encoding[:, h_min:h_max, w_min:w_max, :]
                out_p = apply_decoder(params, encoding_p, target_coords, target_t)
                out[:, scale * h_min:scale * h_max, scale * w_min:scale * w_max, :] = out_p
            # Every output pixel must have been overwritten by some patch.
            assert not np.isnan(out).any()

            # Undo the MEAN/VAR normalisation, then add the nearest-upsampled source.
            out = out * np.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
            out += source_up_
            # Rotate back into the original orientation before averaging.
            outs.append(np.rot90(out, k=i_rot, axes=(-2, -3)))
        out = np.stack(outs).mean(0).clip(0., 1.)

        if save_dir is not None:
            Image.fromarray(np.rint(out[0] * 255).astype(np.uint8)) \
                .save(save_dir / f'{i_img}.png')

        batch_metrics = compute_metrics(
            _crop_border(out, border_crop), _crop_border(target, border_crop),
            compute_ssim=True, y_only=y_only)
        for k, v in batch_metrics.items():
            metrics[k].append(v.item())

    return {k: np.mean(v) for k, v in metrics.items()}
def main(args):
    """Evaluate one checkpoint on every requested dataset/scale pair and print metrics.

    Reads from ``args``: ``data_dir``, ``eval_sets``, ``eval_scales``,
    ``backbone``, ``size``, ``checkpoint``, ``save_dir``, ``no_geo_ensemble``
    and ``y_only``. One line of rounded metrics is printed per (set, scale).
    """
    root = Path(args.data_dir)
    datasets = [ImageFolder(root / name, transforms.ToTensor(), in_memory=False)
                for name in args.eval_sets]
    loaders = [DataLoader(ds, batch_size=1, num_workers=0, shuffle=False)
               for ds in datasets]

    model = build_thera(3, args.backbone, args.size)
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --checkpoint at files from a trusted source.
    with open(args.checkpoint, 'rb') as fh:
        params = pickle.load(fh)['model']

    for eval_set, loader in zip(args.eval_sets, loaders):
        for scale in args.eval_scales:
            # DIV2K sets get an extra 6-pixel border crop on top of `scale`.
            border_crop = scale + 6 if 'DIV2K' in eval_set else scale
            if args.save_dir:
                save_dir = Path(args.save_dir) / ('ours_' + eval_set + '_' + args.backbone) / str(scale)
            else:
                save_dir = None

            results = evaluate(loader, model, params, scale, border_crop,
                               not args.no_geo_ensemble, save_dir, args.y_only)
            # TODO(review): an earlier note asked for 2 digits; metrics are
            # currently rounded to 5 decimals.
            rounded = {name: np.round(value, 5) for name, value in results.items()}
            summary = ' '.join(f'{name}: {value}' for name, value in rounded.items())
            print(f'[{eval_set} x{scale}] ' + summary)


if __name__ == '__main__':
    # Parse command-line flags (parser is defined in the args module) and run.
    args = parser.parse_args()
    main(args)