-
Notifications
You must be signed in to change notification settings - Fork 95
/
dataloader.py
39 lines (27 loc) · 1004 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch.utils.data import Dataset
import numpy as np
import cv2
import os
class PredDataset(Dataset):
    ''' Reads image and trimap pairs from folder.

    Every file in ``img_dir`` whose name ends in ``.png`` is one sample. The
    trimap belonging to a sample is looked up under the same file name inside
    ``trimap_dir``.
    '''

    def __init__(self, img_dir, trimap_dir):
        ''' Collects the names of the PNG files found in ``img_dir``.

        Args:
            img_dir: Directory that holds the input images.
            trimap_dir: Directory that holds one trimap per image, stored
                under the same file name as the image.
        '''
        self.img_dir, self.trimap_dir = img_dir, trimap_dir
        # Test the extension instead of the substring 'png', so entries such
        # as 'png_notes.txt' or 'a.png.bak' are not treated as images. Sort
        # because os.listdir() returns names in arbitrary order, which would
        # otherwise make the index -> sample mapping differ between runs.
        self.img_names = sorted(
            x for x in os.listdir(self.img_dir) if x.endswith('.png')
        )

    def __len__(self):
        ''' Returns the number of image/trimap pairs. '''
        return len(self.img_names)

    def __getitem__(self, idx):
        ''' Loads the sample at position ``idx``.

        Returns:
            A dict with the keys ``'image'`` (output of ``read_image``),
            ``'trimap'`` (output of ``read_trimap``) and ``'name'`` (the file
            name shared by the image and its trimap).
        '''
        img_name = self.img_names[idx]
        image = read_image(os.path.join(self.img_dir, img_name))
        trimap = read_trimap(os.path.join(self.trimap_dir, img_name))
        return {'image': image, 'trimap': trimap, 'name': img_name}
def read_image(name):
    ''' Loads an image file as a float array with reversed channel order.

    The pixel values returned by ``cv2.imread`` are divided by 255.0 and the
    last axis is reversed. OpenCV loads colour images in BGR order, so the
    result is RGB; for an 8-bit image the values lie in [0, 1].

    Args:
        name: Path of the image file.

    Returns:
        A float array indexed (row, column, channel). The channel reversal is
        a slice, so the result is a view rather than a copy.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable or not a
            supported image format).
    '''
    bgr = cv2.imread(name)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the division below fails with an unrelated-looking TypeError.
    if bgr is None:
        raise OSError('Could not read image: {!r}'.format(name))
    return (bgr / 255.0)[:, :, ::-1]
def read_trimap(name):
    ''' Loads a trimap file and encodes it as two binary channels.

    The file is read as a single-channel greyscale image and scaled by
    1/255. The result has shape (h, w, 2):

    * channel 0 is 1 where the scaled pixel equals 0 (presumably definite
      background -- confirm against the consumer of this array),
    * channel 1 is 1 where the scaled pixel equals 1, i.e. raw value 255
      (presumably definite foreground -- confirm likewise).

    Pixels with any other value are 0 in both channels.

    Args:
        name: Path of the trimap file.

    Returns:
        A float array of shape (h, w, 2) containing only 0 and 1.

    Raises:
        OSError: If OpenCV cannot read the file (missing, unreadable or not a
            supported image format).
    '''
    raw = cv2.imread(name, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if raw is None:
        raise OSError('Could not read trimap: {!r}'.format(name))
    trimap_im = raw / 255.0
    h, w = trimap_im.shape
    trimap = np.zeros((h, w, 2))
    trimap[trimap_im == 1, 1] = 1
    trimap[trimap_im == 0, 0] = 1
    return trimap