"""sotabench.py

Evaluates a pretrained EfficientNet-B5 on the ImageNet validation split and
records the predictions with sotabencheval's ImageNetEvaluator.
"""
import os
import numpy as np
import PIL
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import ImageNet
from efficientnet_pytorch import EfficientNet
from sotabencheval.image_classification import ImageNetEvaluator
from sotabencheval.utils import is_server
# Resolve the directory that holds the ImageNet data.
if is_server():
    # On the benchmark server IMAGENET_DIR is optional; fall back to
    # ./imagenet when it is not set.
    DATA_ROOT = os.environ.get('IMAGENET_DIR', './imagenet')  # './.data/vision/imagenet'
else:  # local settings
    # Use .get() so that a missing variable and an empty one both end up at
    # the explanatory error below instead of a bare KeyError. An explicit
    # raise (rather than assert) keeps the check active under `python -O`.
    DATA_ROOT = os.environ.get('IMAGENET_DIR')
    if not DATA_ROOT:
        raise RuntimeError('please set IMAGENET_DIR environment variable')
    print('Local data root: ', DATA_ROOT)
# Model under evaluation. The efficientnet_pytorch helpers are called with the
# lower-cased form of the display name.
model_name = 'EfficientNet-B5'
arch_name = model_name.lower()
model = EfficientNet.from_pretrained(arch_name)
image_size = EfficientNet.get_image_size(arch_name)

# Per-channel statistics handed to transforms.Normalize.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline: bicubic resize and centre crop to the model's
# input size, then tensor conversion and normalisation.
input_transform = transforms.Compose([
    transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
])
# Validation split, preprocessed with the pipeline defined above; targets are
# left untransformed.
test_dataset = ImageNet(
    DATA_ROOT, split="val", transform=input_transform, target_transform=None,
)

# shuffle=False is required: the evaluation loop maps each batch back to file
# names by position in `test_dataset.imgs`.
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True,
)
# Move the network to the GPU and switch it to evaluation mode; both calls
# return the module itself, so they can be chained.
model = model.cuda().eval()

# Collector for the predictions, tagged with the model name and the arXiv id
# of the paper.
evaluator = ImageNetEvaluator(
    paper_arxiv_id='1905.11946',
    model_name=model_name,
)
def get_img_id(image_name):
    """Return the image identifier for a dataset file path.

    Keeps only the text after the last ``/`` in *image_name* and removes
    every ``.JPEG`` substring from it, e.g.
    ``val/n01440764/ILSVRC2012_val_00000293.JPEG`` becomes
    ``ILSVRC2012_val_00000293``.
    """
    _, _, file_name = image_name.rpartition('/')
    return file_name.replace('.JPEG', '')
# Evaluation pass: run the validation set through the model and register the
# raw outputs with the evaluator, keyed by image id.
batch_size = test_loader.batch_size
# The first element of each entry is the image file path.
samples = test_loader.dataset.imgs

with torch.no_grad():  # inference only, no gradients needed
    for batch_idx, (images, _) in enumerate(test_loader):
        # The labels yielded by the loader are ignored: only predictions are
        # handed to the evaluator, so there is no reason to copy them to the
        # GPU.
        images = images.to(device='cuda', non_blocking=True)
        outputs = model(images)

        # The loader iterates in dataset order (shuffle=False), so batch
        # `batch_idx` corresponds to this contiguous slice of the file list.
        start = batch_idx * batch_size
        image_ids = [
            get_img_id(sample[0])
            for sample in samples[start:start + batch_size]
        ]
        evaluator.add(dict(zip(image_ids, outputs.cpu().numpy())))

        # NOTE(review): presumably this means cached results already exist
        # for this run, so the rest of the pass can be skipped - confirm
        # against the sotabencheval documentation.
        if evaluator.cache_exists:
            break

# Print the metrics only for local runs.
if not is_server():
    print("Results:")
    print(evaluator.get_results())

evaluator.save()