-
Notifications
You must be signed in to change notification settings - Fork 0
/
prid.py
139 lines (117 loc) · 4.59 KB
/
prid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# encoding: utf-8
"""
PRID-2011 (single-shot) person re-identification dataset for fastreid.
"""
import os
from scipy.io import loadmat
from glob import glob
import random
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset
import pdb
# Public API of this module: only the dataset class is exported.
__all__ = ["PRID"]
@DATASET_REGISTRY.register()
class PRID(ImageDataset):
    """PRID-2011 single-shot person re-identification dataset.

    Expected layout:
    ``<root>/prid_2011/single_shot/{cam_a,cam_b}/person_XXXX.png``.

    Person ids 1-200 are split into train/test according to
    ``<root>/prid_2011/splits_single_shot.json``. If that file is missing,
    it is created with 10 random 100/100 splits. Images of ids 201-749 are
    read from camera B only and appended to the gallery.

    Training samples use dataset-prefixed string labels and camids
    (e.g. ``("...", "PRID_3", "PRID_0")``). Query and gallery samples use
    the raw integer person id and integer camids 0 (cam A) / 1 (cam B).
    """
    dataset_dir = "prid_2011"
    dataset_name = 'PRID'
    # Person ids 201-749; process_split adds their cam-B images to the gallery.
    _junk_pids = list(range(201, 750))

    def __init__(self, root='datasets', split_id=0, **kwargs):
        """Load split ``split_id`` and build the train/query/gallery lists.

        Args:
            root: Dataset root directory, or a ``[root, split_id]`` list. In
                the list form, the second element overrides ``split_id``.
            split_id: Index of the split to use from ``splits_single_shot.json``.
            **kwargs: Forwarded unchanged to ``ImageDataset.__init__``.

        Raises:
            ValueError: If ``split_id`` is not in ``[0, len(splits) - 1]``.
        """
        if isinstance(root, list):
            split_id = root[1]
            self.root = root[0]
        else:
            # Keep the caller's split_id; it must not be reset to 0 here.
            self.root = root

        self.dataset_dir = os.path.join(self.root, self.dataset_dir)
        self.cam_a_dir = os.path.join(self.dataset_dir, 'single_shot', 'cam_a')
        self.cam_b_dir = os.path.join(self.dataset_dir, 'single_shot', 'cam_b')
        self.split_path = os.path.join(self.dataset_dir, 'splits_single_shot.json')

        required_files = [self.dataset_dir, self.cam_a_dir, self.cam_b_dir]
        self.check_before_run(required_files)

        self.prepare_split()
        splits = self.read_json(self.split_path)
        # Reject negatives too: Python would otherwise index from the end.
        if not 0 <= split_id < len(splits):
            raise ValueError(
                'split_id exceeds range, received {}, but expected between 0 and {}'
                .format(split_id,
                        len(splits) - 1)
            )
        split = splits[split_id]
        train, query, gallery = self.process_split(split)
        super().__init__(train, query, gallery, **kwargs)

    @staticmethod
    def _img_name(pid):
        """Return the image file name for person ``pid``, e.g. ``person_0007.png``."""
        return 'person_' + str(pid).zfill(4) + '.png'

    def prepare_split(self):
        """Write 10 random train/test splits to ``self.split_path`` if absent.

        Each split samples 100 of the ids 1-200 (the ids appearing in both
        views) for training and keeps the other 100 for testing. The split
        uses the global ``random`` state, so it is only reproducible if the
        caller seeds ``random`` beforehand.
        """
        if os.path.exists(self.split_path):
            return
        print('Creating splits ...')
        shared_pids = list(range(1, 201))
        splits = []
        for _ in range(10):
            train_pids = sorted(random.sample(shared_pids, 100))
            train_set = set(train_pids)  # O(1) membership for the test split
            test_pids = [pid for pid in shared_pids if pid not in train_set]
            splits.append({'train': train_pids, 'test': test_pids})
        print('Totally {} splits are created'.format(len(splits)))
        self.write_json(splits, self.split_path)
        print('Split file is saved to {}'.format(self.split_path))

    def process_split(self, split):
        """Build lists of ``(img_path, pid, camid)`` tuples for one split.

        Args:
            split: Dict with ``'train'`` and ``'test'`` lists of person ids.

        Returns:
            ``(train, query, gallery)``. Each training id contributes one
            cam-A and one cam-B sample. Each test id contributes its cam-A
            image to the query and its cam-B image to the gallery. Ids in
            ``self._junk_pids`` add cam-B images to the gallery only.
        """
        train_pids = split['train']
        test_pids = split['test']
        train_pid2label = {pid: label for label, pid in enumerate(train_pids)}

        # Train labels/camids are prefixed with the dataset name (added by hby),
        # presumably so they remain distinct when datasets are mixed in training.
        camid_a = self.dataset_name + "_" + str(0)
        camid_b = self.dataset_name + "_" + str(1)
        train = []
        for pid in train_pids:
            img_name = self._img_name(pid)
            label = self.dataset_name + "_" + str(train_pid2label[pid])
            train.append((os.path.join(self.cam_a_dir, img_name), label, camid_a))
            train.append((os.path.join(self.cam_b_dir, img_name), label, camid_b))

        query, gallery = [], []
        for pid in test_pids:
            img_name = self._img_name(pid)
            query.append((os.path.join(self.cam_a_dir, img_name), pid, 0))
            gallery.append((os.path.join(self.cam_b_dir, img_name), pid, 1))
        # Cam-B-only identities are appended to the gallery as distractors.
        for pid in self._junk_pids:
            gallery.append((os.path.join(self.cam_b_dir, self._img_name(pid)), pid, 1))
        return train, query, gallery

    def read_json(self, fpath):
        """Read and return the JSON object stored at ``fpath``."""
        import json
        with open(fpath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def write_json(self, obj, fpath):
        """Write ``obj`` to ``fpath`` as indented JSON, creating parent dirs."""
        import json
        self.mkdir_if_missing(os.path.dirname(fpath))
        with open(fpath, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=4, separators=(',', ': '))

    def mkdir_if_missing(self, dirname):
        """Create ``dirname`` (with parents) unless it is empty or already exists.

        An empty ``dirname`` (a bare file name's parent) means the current
        directory and is skipped, because ``os.makedirs('')`` would raise.
        """
        if dirname and not os.path.exists(dirname):
            # exist_ok covers a concurrent creator racing us to the directory.
            os.makedirs(dirname, exist_ok=True)