local dl = require 'dataload._env'

-- DataLoader is an abstract class : subclasses are expected to implement
-- the "Not Implemented" methods below (index, shuffle, split, size, isize, tsize).
local DataLoader = torch.class('dl.DataLoader', dl)

-- returns the samples at the given indices (inputs and targets are optional, reusable buffers)
function DataLoader:index(indices, inputs, targets, ...)
   error"Not Implemented"
end
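
-- A concrete loader usually keeps its data in tensors and implements index
-- roughly like this (a sketch; MyLoader and the tensor fields self.inputs /
-- self.targets are assumptions, not part of this abstract class) :
--   function MyLoader:index(indices, inputs, targets)
--      inputs = inputs or self.inputs.new()
--      targets = targets or self.targets.new()
--      inputs:index(self.inputs, 1, indices)
--      targets:index(self.targets, 1, indices)
--      return inputs, targets
--   end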

-- returns a batch of batchsize samples drawn uniformly at random (with replacement)
function DataLoader:sample(batchsize, inputs, targets, ...)
   self._indices = self._indices or torch.LongTensor()
   self._indices:resize(batchsize):random(1,self:size())
   return self:index(self._indices, inputs, targets, ...)
end

-- returns the batch of samples between indices start and stop (inclusive)
function DataLoader:sub(start, stop, inputs, targets, ...)
   self._indices = self._indices or torch.LongTensor()
   self._indices:range(start, stop)
   return self:index(self._indices, inputs, targets, ...)
end
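
-- Example (a sketch; 'loader' stands for any concrete DataLoader subclass) :
--   local inputs, targets = loader:sub(1, 100)
--   -- passing the returned tensors back in reuses them as buffers
--   inputs, targets = loader:sub(101, 200, inputs, targets)
--   inputs, targets = loader:sample(32, inputs, targets)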

-- shuffles the order of the samples in place
function DataLoader:shuffle()
   error"Not Implemented"
end

-- splits the DataLoader into two according to ratio (e.g. for a train/valid split)
function DataLoader:split(ratio)
   error"Not Implemented"
end

-- number of samples
function DataLoader:size()
   error"Not Implemented"
end

-- size of inputs
function DataLoader:isize(excludedim)
   -- by default, batch dimension is excluded
   excludedim = excludedim == nil and 1 or excludedim
   error"Not Implemented"
end

-- size of targets
function DataLoader:tsize(excludedim)
   -- by default, batch dimension is excluded
   excludedim = excludedim == nil and 1 or excludedim
   error"Not Implemented"
end

-- called by AsyncIterator before serializing the DataLoader to threads
function DataLoader:reset()
   self._indices = nil
   self._start = nil
end

-- collect garbage every self.gcdelay times this method is called
function DataLoader:collectgarbage()
   self.gcdelay = self.gcdelay or 200
   self.gccount = (self.gccount or 0) + 1
   if self.gccount >= self.gcdelay then
      collectgarbage()
      self.gccount = 0
   end
end

-- returns a deep copy of the DataLoader (serialized through an in-memory file)
function DataLoader:clone(...)
   local f = torch.MemoryFile("rw"):binary()
   f:writeObject(self)
   f:seek(1)
   local clone = f:readObject()
   f:close()
   if select('#',...) > 0 then
      clone:share(self,...)
   end
   return clone
end
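
-- Example (sketch) :
--   local copy = loader:clone()
--   -- extra arguments, if any, are forwarded to clone:share(self, ...),
--   -- which only works when the subclass defines a share method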

-- iterators : subiter, sampleiter

-- subiter : for iterating over validation and test sets
function DataLoader:subiter(batchsize, epochsize, ...)
   batchsize = batchsize or 32
   local dots = {...}
   local size = self:size()
   epochsize = epochsize or -1
   epochsize = epochsize > 0 and epochsize or self:size()
   self._start = self._start or 1
   local nsampled = 0
   local stop
   local inputs, targets

   -- build iterator
   return function()
      if nsampled >= epochsize then
         return
      end

      local bs = math.min(nsampled+batchsize, epochsize) - nsampled
      stop = math.min(self._start + bs - 1, size)

      -- inputs and targets
      local batch = {self:sub(self._start, stop, inputs, targets, unpack(dots))}
      -- allows reuse of inputs and targets buffers for next iteration
      inputs, targets = batch[1], batch[2]

      bs = stop - self._start + 1
      nsampled = nsampled + bs
      self._start = self._start + bs
      if self._start > size then
         self._start = 1
      end

      self:collectgarbage()
      return nsampled, unpack(batch)
   end
end
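
-- Example (a sketch; 'validset' stands for any concrete DataLoader subclass) :
--   for nsampled, inputs, targets in validset:subiter(100, validset:size()) do
--      -- evaluate the model on inputs, targets
--   end
-- nsampled is the cumulative number of samples returned so far in the epoch.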

-- sampleiter : for iterating over training sets
function DataLoader:sampleiter(batchsize, epochsize, ...)
   batchsize = batchsize or 32
   local dots = {...}
   local size = self:size()
   epochsize = epochsize or -1
   epochsize = epochsize > 0 and epochsize or self:size()
   local nsampled = 0
   local inputs, targets

   -- build iterator
   return function()
      if nsampled >= epochsize then
         return
      end

      local bs = math.min(nsampled+batchsize, epochsize) - nsampled

      -- inputs and targets
      local batch = {self:sample(bs, inputs, targets, unpack(dots))}
      -- allows reuse of inputs and targets buffers for next iteration
      inputs, targets = batch[1], batch[2]

      nsampled = nsampled + bs

      self:collectgarbage()
      return nsampled, unpack(batch)
   end
end
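
-- Example (a sketch; 'trainset' stands for any concrete DataLoader subclass) :
--   for nsampled, inputs, targets in trainset:sampleiter(32, 10000) do
--      -- train the model on the batch of inputs, targets
--   end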