-
Notifications
You must be signed in to change notification settings - Fork 2
/
Super_resolution.py
87 lines (82 loc) · 3.03 KB
/
Super_resolution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch.nn as nn
import torch
class one_conv(nn.Module):
    """One dense layer: convolution + ReLU, concatenated onto its own input.

    The result carries ``inchanels + growth_rate`` channels: the unchanged
    input first, followed by the freshly computed feature maps.
    """

    def __init__(self, inchanels, growth_rate, kernel_size=3):
        """Build the convolution and its activation.

        Args:
            inchanels: Channel count of the incoming feature map.
            growth_rate: Number of new feature maps the layer produces.
            kernel_size: Side of the square kernel. Padding is
                ``kernel_size >> 1`` and the stride is 1.
        """
        super(one_conv, self).__init__()
        half_kernel = kernel_size >> 1
        self.conv = nn.Conv2d(
            in_channels=inchanels,
            out_channels=growth_rate,
            kernel_size=kernel_size,
            stride=1,
            padding=half_kernel,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` with the activated convolution appended along dim 1."""
        new_features = self.conv(x)
        new_features = self.relu(new_features)
        return torch.cat([x, new_features], dim=1)
class RDB(nn.Module):
    """Residual dense block: ``C`` dense layers, a 1x1 fusion conv and a skip.

    Each dense layer concatenates ``G`` new feature maps onto its input, so
    the stack turns ``G0`` channels into ``G0 + C * G``. A 1x1 convolution
    (local feature fusion) maps that back to ``G0`` channels, and the block
    input is added to the result (local residual learning).
    """

    def __init__(self, G0, C, G, kernel_size=3):
        """Build the dense layers and the fusion convolution.

        Args:
            G0: Channel count of the block's input and output.
            C: Number of dense layers in the block.
            G: Growth rate, i.e. feature maps added by each dense layer.
            kernel_size: Kernel size used by every dense layer.
        """
        super(RDB, self).__init__()
        # Layer i receives the block input plus the i * G maps produced by the
        # layers before it. kernel_size is forwarded so that the constructor
        # argument actually takes effect instead of being silently ignored.
        self.conv = nn.Sequential(
            *[one_conv(G0 + i * G, G, kernel_size) for i in range(C)]
        )
        # Local feature fusion: squeeze the concatenated maps back to G0.
        self.LFF = nn.Conv2d(G0 + C * G, G0, kernel_size=1, padding=0, stride=1)

    def forward(self, x):
        """Return the fused dense features plus the block input ``x``."""
        dense = self.conv(x)
        fused = self.LFF(dense)
        # Local residual learning.
        return fused + x
class rdn(nn.Module):
    """Residual dense network built from shallow convs, RDBs and an upsampler.

    Pipeline: two shallow feature-extraction convolutions, ``D`` residual
    dense blocks applied in sequence, global feature fusion over the
    concatenated output of every block, a global residual connection to the
    first shallow feature map, and an upsampling head made of two
    ``PixelShuffle(2)`` stages followed by a convolution to one channel.
    """

    def __init__(self, D=20, C=6, G=32, G0=64, kernel_size=3, input_channels=1):
        """Build the network; the defaults reproduce the original fixed setup.

        Args:
            D: Number of residual dense blocks (must be at least 1).
            C: Number of conv layers inside each RDB.
            G: Growth rate of the dense layers; also the channel count
                between the pixel-shuffle stages of the upsampler.
            G0: Channel count of the shallow, local-fusion and global-fusion
                feature maps.
            kernel_size: Side of the square kernels of the non-1x1
                convolutions created here; it is also passed to every RDB.
            input_channels: Channel count of the network input.

        Raises:
            ValueError: If ``D`` is smaller than 1, since global feature
                fusion would have nothing to concatenate.
        """
        super(rdn, self).__init__()
        if D < 1:
            raise ValueError("rdn needs at least one RDB, got D=%r" % (D,))
        self.D = D
        self.C = C
        self.G = G
        self.G0 = G0
        pad = kernel_size >> 1

        # Shallow feature extraction.
        self.SFE1 = nn.Conv2d(
            input_channels, self.G0,
            kernel_size=kernel_size, padding=pad, stride=1,
        )
        self.SFE2 = nn.Conv2d(
            self.G0, self.G0,
            kernel_size=kernel_size, padding=pad, stride=1,
        )

        # The D residual dense blocks.
        self.RDBS = nn.ModuleList()
        for _ in range(self.D):
            self.RDBS.append(RDB(self.G0, self.C, self.G, kernel_size))

        # Global feature fusion: 1x1 conv over all block outputs, then a
        # regular conv.
        self.GFF = nn.Sequential(
            nn.Conv2d(self.D * self.G0, self.G0, kernel_size=1, padding=0, stride=1),
            nn.Conv2d(self.G0, self.G0, kernel_size, padding=pad, stride=1),
        )

        # Upsampling head: each PixelShuffle(2) turns 4 * G channels into G
        # channels at twice the spatial size, so the two stages give 4x.
        self.up_net = nn.Sequential(
            nn.Conv2d(self.G0, self.G * 4, kernel_size=kernel_size, padding=pad, stride=1),
            nn.PixelShuffle(2),
            nn.Conv2d(self.G, self.G * 4, kernel_size=kernel_size, padding=pad, stride=1),
            nn.PixelShuffle(2),
            nn.Conv2d(self.G, 1, kernel_size=kernel_size, padding=pad, stride=1),
        )

        # Orthogonal weights and zero biases for every convolution.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()

    def forward(self, x):
        """Run the network on a 4-D input and return the sliced result.

        The upsampler enlarges both spatial axes by 4; the final slice then
        keeps every 4th index along dim 2, so only the last axis stays
        enlarged in the returned tensor.
        """
        shallow = self.SFE1(x)
        out = self.SFE2(shallow)

        # Keep every block's output for global feature fusion.
        block_outputs = []
        for block in self.RDBS:
            out = block(out)
            block_outputs.append(out)

        fused = self.GFF(torch.cat(block_outputs, 1))
        # Global residual connection to the first shallow feature map.
        out = self.up_net(shallow + fused)
        # NOTE(review): presumably intentional single-axis super-resolution
        # (dim 2 is subsampled back by 4) - confirm with the data pipeline.
        return out[:, :, ::4, :]