-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataLoad.py
68 lines (44 loc) · 1.87 KB
/
dataLoad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import torch
class ED_dataset(Dataset):
    """Paired (blurred, sharp) single-crop dataset backed by pre-processed ``.npy`` files.

    Each file lives at
    ``<filepath>/<which>/<which>_{blur,shar}_proc/s<folder>/<image>_crop<crop>.npy``
    and every item is the ``(blur, shar)`` pair of arrays returned by
    ``np.load`` for the same folder / image / crop.

    Args:
        filepath: Root directory of the processed data.
        which: Split name, ``'train'`` or ``'val'``. Any other value gives an
            empty dataset (length 0).
    """

    # split -> (first folder index, number of folders, images used per folder)
    #   train: folders s120..s239, only the first 16 images of each folder.
    #   val:   folders s0..s29, 100 images each.
    # NOTE(review): the val layout is taken from RNN_dataset (30 folders x 100
    # frames); the previous fixed "idx // 400 + 120" mapping pointed val indices
    # at folders s120..s307. Confirm against the data on disk.
    _LAYOUT = {
        'train': (120, 120, 16),
        'val': (0, 30, 100),
    }
    # Crops stored per image: files <image>_crop0.npy .. <image>_crop24.npy.
    _CROPS_PER_IMAGE = 25

    def __init__(self, filepath, which):
        self.filepath = filepath
        self.which = which

    def __len__(self):
        """Return folders * images-per-folder * crops for the split, or 0 for an unknown split."""
        layout = self._LAYOUT.get(self.which)
        if layout is None:
            return 0
        _, num_folders, images_per_folder = layout
        return num_folders * images_per_folder * self._CROPS_PER_IMAGE

    def __getitem__(self, idx):
        """Load the ``(blur, shar)`` pair for flat index *idx*.

        The flat index is decomposed as folder-major, then image, then crop.

        Raises:
            IndexError: if *idx* is outside ``[0, len(self))``. This lets plain
                ``for x in dataset`` iteration stop cleanly instead of failing on
                a missing file.
        """
        size = len(self)
        if not 0 <= idx < size:
            raise IndexError('index %s out of range for %r split of size %s'
                             % (idx, self.which, size))
        first_folder, _, images_per_folder = self._LAYOUT[self.which]
        items_per_folder = images_per_folder * self._CROPS_PER_IMAGE
        folder = idx // items_per_folder + first_folder
        image = (idx % items_per_folder) // self._CROPS_PER_IMAGE
        crop = idx % self._CROPS_PER_IMAGE
        blur = np.load(self._npy_path('blur', folder, image, crop))
        shar = np.load(self._npy_path('shar', folder, image, crop))
        return blur, shar

    def _npy_path(self, kind, folder, image, crop):
        """Build the path of one crop file; *kind* is ``'blur'`` or ``'shar'``."""
        return os.path.join(self.filepath,
                            self.which,
                            self.which + '_' + kind + '_proc',
                            's' + str(folder),
                            str(image) + '_crop' + str(crop) + '.npy')
class RNN_dataset(Dataset):
    """Sequence dataset: each item stacks consecutive frames of one crop position.

    Frame files live at
    ``<filepath>/<which>/<which>_{blur,shar}_proc/s<folder>/<frame>_crop<crop>.npy``
    for ``frame`` in ``0 .. num_frames - 1``.

    Args:
        filepath: Root directory of the processed data.
        which: Split name, ``'train'`` (240 folders) or ``'val'`` (30 folders).
            Any other value gives an empty dataset.
        num_frames: Number of frame files stacked per item (default 100).
        frame_shape: Shape of one stacked frame (default ``(3, 144, 256)``).
            It must match the sharp file's shape and the shape of channels 6:9
            of the blur file.
    """

    # split -> number of sequence folders (s0 .. s<N-1>)
    _NUM_FOLDERS = {'train': 240, 'val': 30}
    # Crops stored per frame: files <frame>_crop0.npy .. <frame>_crop24.npy.
    _CROPS_PER_IMAGE = 25

    def __init__(self, filepath, which, num_frames=100, frame_shape=(3, 144, 256)):
        self.filepath = filepath
        self.which = which
        self.num_frames = num_frames
        self.frame_shape = tuple(frame_shape)

    def __len__(self):
        """Return folders * crops for the split, or 0 for an unknown split."""
        return self._NUM_FOLDERS.get(self.which, 0) * self._CROPS_PER_IMAGE

    def __getitem__(self, idx):
        """Return float64 ``(blur, shar)`` stacks of shape ``(num_frames, *frame_shape)``.

        Raises:
            IndexError: if *idx* is outside ``[0, len(self))``.
        """
        size = len(self)
        if not 0 <= idx < size:
            raise IndexError('index %s out of range for %r split of size %s'
                             % (idx, self.which, size))
        folder, crop = divmod(idx, self._CROPS_PER_IMAGE)
        stack_shape = (self.num_frames,) + self.frame_shape
        blur = np.zeros(stack_shape)
        shar = np.zeros(stack_shape)
        for i in range(self.num_frames):
            # Frame files are named by the loop index (the old code used an
            # undefined name `image` here and raised NameError).
            # NOTE(review): only channels 6:9 of each blur file are kept;
            # presumably the blur file stacks several 3-channel frames. Confirm.
            blur[i, :, :, :] = np.load(self._npy_path('blur', folder, i, crop))[6:9, :, :]
            shar[i, :, :, :] = np.load(self._npy_path('shar', folder, i, crop))
        return blur, shar

    def _npy_path(self, kind, folder, image, crop):
        """Build the path of one crop file; *kind* is ``'blur'`` or ``'shar'``."""
        return os.path.join(self.filepath,
                            self.which,
                            self.which + '_' + kind + '_proc',
                            's' + str(folder),
                            str(image) + '_crop' + str(crop) + '.npy')