-
Notifications
You must be signed in to change notification settings - Fork 0
/
ODConv2d.py
155 lines (115 loc) · 5.82 KB
/
ODConv2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import math
import torch
import torch.nn as nn
class ODConv2d(nn.Conv2d):
    """Omni-dimensional dynamic 2-D convolution.

    Stores ``K`` candidate kernels and, for every sample, builds one kernel by
    modulating them with attentions computed from the global average of a
    context tensor (the input itself by default):

    * ``fc_f`` -- per output channel (sigmoid),
    * ``fc_s`` -- per kernel spatial position (sigmoid),
    * ``fc_c`` -- per input channel within a group (sigmoid),
    * ``fc_w`` -- over the ``K`` candidate kernels (softmax).

    The per-sample kernels are applied with a single grouped convolution by
    folding the batch dimension into the channel dimension.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode, device, dtype: as for ``nn.Conv2d``.
        K: number of candidate kernels.
        r: reduction ratio; the attention hidden width is
            ``max(int(in_channels * r), 1)``.
        save_parameters: if True, skip the ``fc_s`` / ``fc_c`` / ``fc_w``
            branches whose output would have a single element (1x1 kernel,
            one input channel per group, or ``K == 1`` respectively).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 K=4, r=1 / 16, save_parameters=False,
                 padding_mode='zeros', device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Set before super().__init__(), which already calls reset_parameters().
        self.K = K
        self.r = r
        self.save_parameters = save_parameters
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode,
                         **factory_kwargs)
        # Replace the single static kernel (and bias) with K candidates.
        del self.weight
        self.weight = nn.Parameter(torch.empty(
            (K, out_channels, in_channels // groups, *self.kernel_size),
            **factory_kwargs))
        if bias:
            del self.bias
            self.bias = nn.Parameter(torch.empty(K, out_channels, **factory_kwargs))

        # int(in_channels * r) is 0 for narrow inputs, which would make every
        # attention a constant (bias-only) independent of the input.
        hidden_dim = max(int(in_channels * r), 1)
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.reduction = nn.Linear(in_channels, hidden_dim, **factory_kwargs)
        self.act = nn.ReLU(inplace=True)
        self.fc_f = nn.Linear(hidden_dim, out_channels, **factory_kwargs)
        if not save_parameters or kernel_area > 1:
            self.fc_s = nn.Linear(hidden_dim, kernel_area, **factory_kwargs)
        if not save_parameters or in_channels // groups > 1:
            self.fc_c = nn.Linear(hidden_dim, in_channels // groups, **factory_kwargs)
        if not save_parameters or K > 1:
            self.fc_w = nn.Linear(hidden_dim, K, **factory_kwargs)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise kernels with fan-out normal init and zero the biases.

        ``nn.Conv2d.__init__`` also calls this while ``self.weight`` still has
        the plain ``(c_out, c_in // groups, kh, kw)`` shape, so the whole
        tensor is initialised at once instead of indexing a ``K`` dimension.
        """
        fan_out = self.kernel_size[0] * self.kernel_size[1] * self.out_channels // self.groups
        nn.init.normal_(self.weight, 0.0, math.sqrt(2.0 / fan_out))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def extra_repr(self):
        return super().extra_repr() + f', K={self.K}, r={self.r:.4}'

    def get_weight_bias(self, context):
        """Build per-sample kernels and biases from ``context``.

        Args:
            context: 4-D tensor ``(B, in_channels, H, W)``; only its global
                spatial average is used.

        Returns:
            ``(weight, bias)``: ``weight`` has shape
            ``(B * out_channels, in_channels // groups, kh, kw)``; ``bias`` has
            shape ``(B * out_channels,)`` or is ``None``.

        Raises:
            ValueError: if ``context`` does not have ``in_channels`` channels.
        """
        B, C, H, W = context.shape
        if C != self.in_channels:
            raise ValueError(
                f"Expected context{[B, C, H, W]} to have {self.in_channels} channels, but got {C} channels instead")
        x = self.gap(context).squeeze(-1).squeeze(-1)  # B, c_in
        x = self.reduction(x)  # B, hidden_dim
        x = self.act(x)
        attn_f = self.fc_f(x).sigmoid()  # B, c_out
        attn = attn_f.view(B, 1, -1, 1, 1, 1)  # B, 1, c_out, 1, 1, 1
        if hasattr(self, 'fc_s'):
            attn_s = self.fc_s(x).sigmoid()  # B, kh * kw
            attn = attn * attn_s.view(B, 1, 1, 1, *self.kernel_size)  # B, 1, c_out, 1, kh, kw
        if hasattr(self, 'fc_c'):
            attn_c = self.fc_c(x).sigmoid()  # B, c_in // groups
            attn = attn * attn_c.view(B, 1, 1, -1, 1, 1)  # B, 1, c_out, c_in // groups, kh, kw
        if hasattr(self, 'fc_w'):
            attn_w = self.fc_w(x).softmax(-1)  # B, K
            attn = attn * attn_w.view(B, -1, 1, 1, 1, 1)  # B, K, c_out, c_in // groups, kh, kw
        weight = (attn * self.weight).sum(1)  # B, c_out, c_in // groups, kh, kw
        weight = weight.view(-1, self.in_channels // self.groups, *self.kernel_size)  # B * c_out, c_in // groups, kh, kw
        bias = None
        if self.bias is not None:
            if hasattr(self, 'fc_w'):
                bias = attn_w @ self.bias  # (B, K) @ (K, c_out) -> B, c_out
            else:
                # fc_w is only skipped when K == 1, so bias is (1, c_out).
                bias = self.bias.tile(B, 1)
            bias = bias.view(-1)  # B * c_out
        return weight, bias

    def _check_input(self, input, context):
        """Validate ``input`` and return the tensor the attentions are computed from.

        Raises:
            ValueError: if ``input`` has the wrong channel count or ``context``
                has a different batch size than ``input``.
        """
        B, C, H, W = input.shape
        if C != self.in_channels:
            raise ValueError(
                f"Expected input{[B, C, H, W]} to have {self.in_channels} channels, but got {C} channels instead")
        # Explicit None test: `context or input` would call bool() on a tensor.
        if context is None:
            return input
        if context.shape[:1] != input.shape[:1]:
            raise ValueError(
                f"Expected context{list(context.shape)} to have batch size {B}")
        return context

    def forward(self, input, context=None):
        """Convolve ``input`` with kernels conditioned on ``context`` (default: ``input``)."""
        context = self._check_input(input, context)
        B, C = input.shape[:2]
        weight, bias = self.get_weight_bias(context)
        padding = self.padding
        if self.padding_mode != 'zeros':
            # Same handling as nn.Conv2d: pad explicitly, then convolve unpadded.
            input = nn.functional.pad(
                input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            padding = 0
        # reshape (not view) so non-contiguous inputs such as channels_last work.
        output = nn.functional.conv2d(
            input.reshape(1, B * C, *input.shape[2:]), weight, bias,
            self.stride, padding, self.dilation, B * self.groups)  # 1, B * c_out, h_out, w_out
        return output.view(B, self.out_channels, *output.shape[2:])

    def debug(self, input, context=None):
        """Reference implementation of ``forward`` via ``unfold`` + batched matmul."""
        context = self._check_input(input, context)
        B = input.shape[0]
        weight, bias = self.get_weight_bias(context)
        # unfold() supports neither padding modes nor 'same'/'valid' strings,
        # so always pad explicitly and unfold without padding.
        mode = 'constant' if self.padding_mode == 'zeros' else self.padding_mode
        input = nn.functional.pad(input, self._reversed_padding_repeated_twice, mode=mode)
        output_size = [
            (size - self.dilation[i] * (self.kernel_size[i] - 1) - 1) // self.stride[i] + 1
            for i, size in enumerate(input.shape[2:])
        ]
        weight = weight.view(B, self.groups, self.out_channels // self.groups, -1)  # B, groups, c_out // groups, c_in // groups * kh * kw
        unfold = nn.functional.unfold(
            input, self.kernel_size, self.dilation, 0, self.stride)  # B, c_in * kh * kw, H_out * W_out
        unfold = unfold.view(B, self.groups, -1, output_size[0] * output_size[1])  # B, groups, c_in // groups * kh * kw, H_out * W_out
        output = weight @ unfold  # B, groups, c_out // groups, H_out * W_out
        output = output.view(B, self.out_channels, *output_size)  # B, c_out, H_out, W_out
        if bias is not None:
            output = output + bias.view(B, self.out_channels, 1, 1)
        return output
if __name__ == "__main__":
    # Smoke test: the grouped-conv fast path and the unfold reference must
    # agree, then compare parameter counts against a plain depthwise conv.
    torch.manual_seed(0)  # reproducible demo
    x = torch.randn(2, 60, 5, 5)
    conv = nn.Conv2d(60, 60, 7, padding=3, groups=60)
    odconv = ODConv2d(60, 60, 7, padding=3, groups=60, K=2, r=1/6)
    out1 = odconv(x)
    out2 = odconv.debug(x)
    # The two paths sum in different orders; atol=1e-7 left almost no slack
    # for near-zero float32 outputs, so use a tolerance above rounding noise.
    torch.testing.assert_close(out1, out2, rtol=1e-5, atol=1e-5)
    print(odconv)
    print("parameters:", sum(p.numel() for p in odconv.parameters() if p.requires_grad))
    print(conv)
    print("parameters:", sum(p.numel() for p in conv.parameters() if p.requires_grad))