-
Notifications
You must be signed in to change notification settings - Fork 2
/
registernumbers.lua
115 lines (107 loc) · 3.09 KB
/
registernumbers.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
local register_ = require 'torch.register'
local torch = require 'torch.env'
--- Shallow-copy a table.
-- Every key/value pair of `args` is copied into a new table; nested tables
-- are shared with the original, not duplicated.
-- @tparam table args table to copy
-- @treturn table a new table with the same entries
local function copy_args(args)
  local copy = {}
  for key, value in pairs(args) do copy[key] = value end
  return copy
end
-- handle numbers type

-- Lua 5.1/LuaJIT provide `loadstring`; Lua 5.2+ only provide `load`, which
-- accepts a source string there.
local load_source = loadstring or load

-- Maximum number of plain number arguments accepted in place of a <numbers>
-- argument. The storage handed to `call` in that variant always has this many
-- slots; unused trailing slots hold the default 0.
local MAX_NUMBERS = 5

--- Return a fresh list {'arg1', ..., 'argN'} of generated parameter names.
local function arg_names(n)
  local names = {}
  for i=1,n do
    names[i] = string.format('arg%d', i)
  end
  return names
end

--- Compile a generated wrapper and bind its two free variables.
-- The generated chunk starts with `local numbers, call = ...`, so the shared
-- storage and the wrapped function are passed in as chunk arguments instead of
-- being patched in with debug.setupvalue (the Lua manual states upvalues are
-- numbered in an arbitrary order).
local function compile_wrapper(code, numbers, call)
  local chunk, err = load_source(code)
  if not chunk then
    error(err)
  end
  return chunk(numbers, call)
end

--- Register the variant where the <numbers> argument is given as a Lua table.
-- The table's elements are copied into one LongStorage (reused on every call)
-- before `args.call` is invoked with that storage in the argument's position.
-- NOTE(review): because the storage is shared, `call` presumably must not keep
-- a reference to it — confirm with the callees.
local function register_table(args, nidx, namespace, metatable)
  local new_args = copy_args(args)
  new_args[nidx] = copy_args(args[nidx]) -- avoid modification with no warning
  new_args[nidx].type = 'table'

  local funcargs = arg_names(#new_args)
  local callargs = arg_names(#new_args)
  callargs[nidx] = 'numbers'

  local code = string.format([[
local numbers, call = ...
return function(%s)
  local sz = #arg%d
  numbers:resize(sz)
  for i=1,sz do
    numbers.__data[i-1] = arg%d[i]
  end
  return call(%s)
end
]], table.concat(funcargs, ', '), nidx, nidx, table.concat(callargs, ', '))

  new_args.call = compile_wrapper(code, torch.LongStorage(), args.call)
  register_(new_args, namespace, metatable)
end

--- Register the variant where the <numbers> argument is spelled as up to `n`
-- separate number arguments (name1, name2, ...; all but the first default to 0).
-- They are written into one `n`-slot LongStorage (reused on every call), which
-- replaces the whole group in the call to `args.call`.
local function register_numbers(args, nidx, namespace, metatable, n)
  local new_args = copy_args(args)
  table.remove(new_args, nidx)
  for i=1,n do
    local arg = copy_args(args[nidx])
    arg.name = arg.name .. i
    arg.type = 'number'
    if i > 1 then
      arg.default = 0
    end
    table.insert(new_args, nidx+i-1, arg)
  end

  local funcargs = arg_names(#new_args)
  local callargs = arg_names(#new_args)
  -- the n number parameters collapse into a single `numbers` call argument
  callargs[nidx] = 'numbers'
  for _=2,n do
    table.remove(callargs, nidx+1)
  end

  local fill = {}
  for i=1,n do
    fill[i] = string.format('  numbers.__data[%d] = arg%d', i-1, nidx+i-1)
  end

  local code = string.format([[
local numbers, call = ...
return function(%s)
%s
  return call(%s)
end
]], table.concat(funcargs, ', '), table.concat(fill, '\n'), table.concat(callargs, ', '))

  new_args.call = compile_wrapper(code, torch.LongStorage(n), args.call)
  register_(new_args, namespace, metatable)
end

--- Register the variant where the <numbers> argument is given directly as a
-- torch.LongStorage and forwarded to `args.call` unchanged.
local function register_storage(args, nidx, namespace, metatable)
  local new_args = copy_args(args)
  -- copy the argument description so the caller's `args` is not mutated
  new_args[nidx] = copy_args(args[nidx])
  new_args[nidx].type = 'torch.LongStorage'
  register_(new_args, namespace, metatable)
end

--- Register a function description, expanding a <numbers> argument.
-- If no argument has type 'numbers', `args` is passed to torch.register as-is.
-- Otherwise three descriptions are registered, in this order: the <numbers>
-- argument as a table, as up to MAX_NUMBERS plain numbers, and as a
-- torch.LongStorage. The caller's `args` table is never modified.
-- @tparam table args argument description (array of argument specs plus `call`)
-- @param namespace forwarded to torch.register
-- @param metatable forwarded to torch.register
-- @raise if more than one argument has type 'numbers', or if a <numbers>
--   argument is present without `args.call`
local function register(args, namespace, metatable)
  local nidx
  for idx,arg in ipairs(args) do
    if arg.type == 'numbers' then
      if nidx then
        error('only one argument can be of <numbers> type')
      end
      nidx = idx
    end
  end

  if not nidx then
    register_(args, namespace, metatable)
    return
  end

  assert(args.call, '<numbers> is supposed to be used together with <call>')
  register_table(args, nidx, namespace, metatable)
  register_numbers(args, nidx, namespace, metatable, MAX_NUMBERS)
  register_storage(args, nidx, namespace, metatable)
end
return register