From eba265972dc86b100bf2ba963a916bc7b8e54ee4 Mon Sep 17 00:00:00 2001
From: jyotirmay
Date: Fri, 29 Jan 2021 15:51:55 +0100
Subject: [PATCH] adding octave convolution module

---
 nn_common_modules/modules.py                  | 233 +++++++++++++++++-
 nn_common_modules/octave_convolution_block.py | 140 +++++++++++
 2 files changed, 372 insertions(+), 1 deletion(-)
 create mode 100644 nn_common_modules/octave_convolution_block.py

diff --git a/nn_common_modules/modules.py b/nn_common_modules/modules.py
index 80b42a7..9238005 100644
--- a/nn_common_modules/modules.py
+++ b/nn_common_modules/modules.py
@@ -17,7 +17,7 @@
 import torch.nn as nn
 from squeeze_and_excitation import squeeze_and_excitation as se
 import torch.nn.functional as F
-
+from .octave_convolution_block import *
 
 class DenseBlock(nn.Module):
     """Block with dense connections
@@ -489,3 +489,234 @@ def forward(self, input, out_block=None, indices=None):
         x1 = self.conv(concat)
         x2 = self.relu(x1)
         return x2
+
+
+class OctaveDenseBlock(nn.Module):
+    """Block with dense connections, built from octave convolutions
+
+    :param params: {
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':28,
+        'se_block': se.SELayer.None,
+        'drop_out':0.2}
+    :type params: dict
+    :param se_block_type: Squeeze and Excite block type to be included, defaults to None
+    :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional
+    :return: forward passed tensor
+    :rtype: torch.tensor [FloatTensor]
+    """
+
+    def __init__(self, params, se_block_type=None, is_decoder=False):
+        super(OctaveDenseBlock, self).__init__()
+        print(se_block_type)
+        if se_block_type == se.SELayer.CSE.value:
+            self.SELayer = se.ChannelSELayer(params['num_filters'])
+
+        elif se_block_type == se.SELayer.SSE.value:
+            self.SELayer = se.SpatialSELayer(params['num_filters'])
+
+        elif se_block_type == se.SELayer.CSSE.value:
+            self.SELayer = se.ChannelSpatialSELayer(params['num_filters'])
+        else:
+            self.SELayer = None
+
+        padding_h = int((params['kernel_h'] - 1) / 2)
+        padding_w = int((params['kernel_w'] - 1) / 2)
+
+        conv1_out_size = int(params['num_channels'] + params['num_filters'])
+        conv2_out_size = int(
+            params['num_channels'] + params['num_filters'] + params['num_filters'])
+
+        self.conv1 = OctConv2d('first', in_channels=params['num_channels'], out_channels=params['num_filters'],
+                               kernel_size=(params['kernel_h'], params['kernel_w']),
+                               padding=(padding_h, padding_w), stride=params['stride_conv'])
+
+        self.conv2 = OctConv2d('regular', in_channels=conv1_out_size, out_channels=params['num_filters'],
+                               kernel_size=(params['kernel_h'], params['kernel_w']),
+                               padding=(padding_h, padding_w), stride=params['stride_conv'])
+
+        self.conv3 = OctConv2d('last', in_channels=conv2_out_size, out_channels=params['num_filters'],
+                               kernel_size=(params['kernel_h'], params['kernel_w']),
+                               padding=(padding_h, padding_w), stride=params['stride_conv'])
+
+        self.avgpool1 = nn.AvgPool2d(2)
+
+        alpha_in, alpha_out = 0.5, 0.5
+
+        if is_decoder:
+            # Channel sizes for decoder blocks
+            self.batchnorm1 = nn.BatchNorm2d(params['num_channels'])
+            oct_unit_decoder_channel_size = params['num_channels'] // 4  # 128
+            self.batchnorm2_h = nn.BatchNorm2d(int(oct_unit_decoder_channel_size*3))  # conv1_out_size * (1 - alpha_out)
+            self.batchnorm2_l = nn.BatchNorm2d(int(oct_unit_decoder_channel_size*3))  # conv1_out_size * alpha_out
+
+            self.batchnorm3_h = nn.BatchNorm2d(int(oct_unit_decoder_channel_size*4))  # conv2_out_size * (1 - alpha_out)
+            self.batchnorm3_l = nn.BatchNorm2d(int(oct_unit_decoder_channel_size*4))  # conv2_out_size * alpha_out
+        else:
+            # Channel sizes for encoder blocks
+            self.batchnorm1 = nn.BatchNorm2d(params['num_channels'])
+            oct_unit_encoder_channel_size = params['num_channels'] // 2  # 64
+            self.batchnorm2_h = nn.BatchNorm2d(int(oct_unit_encoder_channel_size*2))  # conv1_out_size * (1 - alpha_out)
+            self.batchnorm2_l = nn.BatchNorm2d(int(oct_unit_encoder_channel_size*2))  # conv1_out_size * alpha_out
+
+            self.batchnorm3_h = nn.BatchNorm2d(int(oct_unit_encoder_channel_size*3))  # conv2_out_size * (1 - alpha_out)
+            self.batchnorm3_l = nn.BatchNorm2d(int(oct_unit_encoder_channel_size*3))  # conv2_out_size * alpha_out
+
+        self.prelu = nn.PReLU()
+        self.prelu_h = nn.PReLU()
+        self.prelu_l = nn.PReLU()
+
+        if params['drop_out'] > 0:
+            self.drop_out_needed = True
+            self.drop_out = nn.Dropout2d(params['drop_out'])
+        else:
+            self.drop_out_needed = False
+
+    def forward(self, input):
+        """Forward pass
+
+        :param input: Input tensor, shape = (N x C x H x W)
+        :type input: torch.tensor [FloatTensor]
+        :return: Forward passed tensor
+        :rtype: torch.tensor [FloatTensor]
+        """
+
+        o1 = self.batchnorm1(input)
+        o2 = self.prelu(o1)
+
+        # Split the input into high- and low-frequency halves; the low half is
+        # average-pooled to half resolution for the low-frequency octave path.
+        ch = input.shape[1] // 2
+        inp_h, inp_l = input[:, :ch, :, :], input[:, ch:, :, :]
+        inp_ll = self.avgpool1(inp_l)
+
+        o3_h, o3_l = self.conv1(o2)
+        o4_h = torch.cat((inp_h, o3_h), dim=1)
+        o4_l = torch.cat((inp_ll, o3_l), dim=1)
+        o5_h = self.batchnorm2_h(o4_h)
+        o5_l = self.batchnorm2_l(o4_l)
+        o6_h = self.prelu_h(o5_h)
+        o6_l = self.prelu_l(o5_l)
+        o7_h, o7_l = self.conv2((o6_h, o6_l))
+
+        o8_h = torch.cat((inp_h, o3_h, o7_h), dim=1)
+        o8_l = torch.cat((inp_ll, o3_l, o7_l), dim=1)
+
+        o9_h = self.batchnorm3_h(o8_h)
+        o9_l = self.batchnorm3_l(o8_l)
+        o10_h = self.prelu_h(o9_h)
+        o10_l = self.prelu_l(o9_l)
+        out = self.conv3((o10_h, o10_l))
+
+        return out
+
+
+class OctaveEncoderBlock(OctaveDenseBlock):
+    """Dense encoder block with maxpool and an optional SE block
+
+    :param params: {
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':28,
+        'se_block': se.SELayer.None,
+        'drop_out':0.2}
+    :type params: dict
+    :param se_block_type: Squeeze and Excite block type to be included, defaults to None
+    :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional
+    :return: output tensor with maxpool, output tensor without maxpool, indices for unpooling
+    :rtype: torch.tensor [FloatTensor], torch.tensor [FloatTensor], torch.tensor [LongTensor]
+    """
+
+    def __init__(self, params, se_block_type=None):
+        super(OctaveEncoderBlock, self).__init__(params, se_block_type=se_block_type, is_decoder=False)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=params['pool'], stride=params['stride_pool'], return_indices=True)
+
+    def forward(self, input, weights=None):
+        """Forward pass
+
+        :param input: Input tensor, shape = (N x C x H x W)
+        :type input: torch.tensor [FloatTensor]
+        :param weights: Weights used for squeeze and excitation, shape depends on the type of SE block, defaults to None
+        :type weights: torch.tensor, optional
+        :return: output tensor with maxpool, output tensor without maxpool, indices for unpooling
+        :rtype: torch.tensor [FloatTensor], torch.tensor [FloatTensor], torch.tensor [LongTensor]
+        """
+
+        out_block = super(OctaveEncoderBlock, self).forward(input)
+        if self.SELayer:
+            out_block = self.SELayer(out_block)
+
+        if self.drop_out_needed:
+            out_block = self.drop_out(out_block)
+
+        out_encoder, indices = self.maxpool(out_block)
+        return out_encoder, out_block, indices
+
+
+class OctaveDecoderBlock(OctaveDenseBlock):
+    """Dense decoder block with max-unpool and optional skip connection and SE block
+
+    :param params: {
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':28,
+        'se_block': se.SELayer.None,
+        'drop_out':0.2}
+    :type params: dict
+    :param se_block_type: Squeeze and Excite block type to be included, defaults to None
+    :type se_block_type: str, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'}, optional
+    :return: forward passed tensor
+    :rtype: torch.tensor [FloatTensor]
+    """
+
+    def __init__(self, params, se_block_type=None):
+        super(OctaveDecoderBlock, self).__init__(params, se_block_type=se_block_type, is_decoder=True)
+        self.unpool = nn.MaxUnpool2d(
+            kernel_size=params['pool'], stride=params['stride_pool'])
+
+    def forward(self, input, out_block=None, indices=None, weights=None):
+        """Forward pass
+
+        :param input: Input tensor, shape = (N x C x H x W)
+        :type input: torch.tensor [FloatTensor]
+        :param out_block: Tensor for skip connection, shape = (N x C x H x W), defaults to None
+        :type out_block: torch.tensor [FloatTensor], optional
+        :param indices: Indices used for unpooling operation, defaults to None
+        :type indices: torch.tensor, optional
+        :param weights: Weights used for squeeze and excitation, shape depends on the type of SE block, defaults to None
+        :type weights: torch.tensor, optional
+        :return: Forward passed tensor
+        :rtype: torch.tensor [FloatTensor]
+        """
+        if indices is not None:
+            unpool = self.unpool(input, indices)
+        else:
+            # TODO: Implement Conv Transpose
+            print("You have to use Conv Transpose")
+
+        if out_block is not None:
+            concat = torch.cat((out_block, unpool), dim=1)
+        else:
+            concat = unpool
+        out_block = super(OctaveDecoderBlock, self).forward(concat)
+
+        if self.SELayer:
+            out_block = self.SELayer(out_block)
+
+        if self.drop_out_needed:
+            out_block = self.drop_out(out_block)
+        return out_block
diff --git a/nn_common_modules/octave_convolution_block.py b/nn_common_modules/octave_convolution_block.py
new file mode 100644
index 0000000..50bc582
--- /dev/null
+++ b/nn_common_modules/octave_convolution_block.py
@@ -0,0 +1,140 @@
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from math import ceil, floor
+
+from torch.nn.modules.utils import _single, _pair, _triple
+
+class OctConv2d(nn.modules.conv._ConvNd):
+    """Unofficial implementation of the Octave Convolution from the "Drop an Octave" paper.
+
+    oct_type (str): The type of OctConv to use. ['first', 'A'] both stand for the first
+    Octave Convolution, ['last', 'C'] both stand for the last Octave Convolution, and
+    'regular' stands for the regular ones.
+    """
+
+    def __init__(self, oct_type, in_channels, out_channels, kernel_size, stride=1, padding=0,
+                 dilation=1, bias=True, alpha_in=0.5, alpha_out=0.5):
+
+        if oct_type not in ('regular', 'first', 'last', 'A', 'C'):
+            raise InvalidOctType("Invalid oct_type was chosen!")
+
+        oct_type_dict = {'first': (0, alpha_out), 'A': (0, alpha_out),
+                         'last': (alpha_in, 0), 'C': (alpha_in, 0),
+                         'regular': (alpha_in, alpha_out)}
+
+        kernel_size = _pair(kernel_size)
+        stride = _pair(stride)
+
+        # TODO: Make it work with any padding
+        padding = _pair(int((kernel_size[0] - 1) / 2))
+        # padding = _pair(padding)
+        dilation = _pair(dilation)
+        super(OctConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
+                                        dilation, False, _pair(0), 1, bias, padding_mode='zeros')
+
+        # Get alphas from the oct_type_dict
+        self.oct_type = oct_type
+        self.alpha_in, self.alpha_out = oct_type_dict[self.oct_type]
+
+        self.num_high_in_channels = int((1 - self.alpha_in) * in_channels)
+        self.num_low_in_channels = int(self.alpha_in * in_channels)
+        self.num_high_out_channels = int((1 - self.alpha_out) * out_channels)
+        self.num_low_out_channels = int(self.alpha_out * out_channels)
+
+        self.high_hh_weight = self.weight[:self.num_high_out_channels, :self.num_high_in_channels, :, :].clone()
+        self.high_hh_bias = self.bias[:self.num_high_out_channels].clone()
+
+        self.high_hl_weight = self.weight[self.num_high_out_channels:, :self.num_high_in_channels, :, :].clone()
+        self.high_hl_bias = self.bias[self.num_high_out_channels:].clone()
+
+        self.low_lh_weight = self.weight[:self.num_high_out_channels, self.num_high_in_channels:, :, :].clone()
+        self.low_lh_bias = self.bias[:self.num_high_out_channels].clone()
+
+        self.low_ll_weight = self.weight[self.num_high_out_channels:, self.num_high_in_channels:, :, :].clone()
+        self.low_ll_bias = self.bias[self.num_high_out_channels:].clone()
+
+        self.high_hh_weight.data, self.high_hl_weight.data, self.low_lh_weight.data, self.low_ll_weight.data = \
+            self._apply_noise(self.high_hh_weight.data), self._apply_noise(self.high_hl_weight.data), \
+            self._apply_noise(self.low_lh_weight.data), self._apply_noise(self.low_ll_weight.data)
+
+        self.high_hh_weight, self.high_hl_weight, self.low_lh_weight, self.low_ll_weight = \
+            nn.Parameter(self.high_hh_weight), nn.Parameter(self.high_hl_weight), \
+            nn.Parameter(self.low_lh_weight), nn.Parameter(self.low_ll_weight)
+
+        self.high_hh_bias, self.high_hl_bias, self.low_lh_bias, self.low_ll_bias = \
+            nn.Parameter(self.high_hh_bias), nn.Parameter(self.high_hl_bias), \
+            nn.Parameter(self.low_lh_bias), nn.Parameter(self.low_ll_bias)
+
+        self.avgpool = nn.AvgPool2d(2)
+
+    def forward(self, x):
+        if self.oct_type in ('first', 'A'):
+            high_group, low_group = x[:, :self.num_high_in_channels, :, :], x[:, self.num_high_in_channels:, :, :]
+        else:
+            high_group, low_group = x
+
+        high_group_hh = F.conv2d(high_group, self.high_hh_weight, self.high_hh_bias, self.stride,
+                                 self.padding, self.dilation, self.groups)
+        high_group_pooled = self.avgpool(high_group)
+
+        if self.oct_type in ('first', 'A'):
+            high_group_hl = F.conv2d(high_group_pooled, self.high_hl_weight, self.high_hl_bias, self.stride,
+                                     self.padding, self.dilation, self.groups)
+            high_group_out, low_group_out = high_group_hh, high_group_hl
+
+            return high_group_out, low_group_out
+
+        elif self.oct_type in ('last', 'C'):
+            low_group_lh = F.conv2d(low_group, self.low_lh_weight, self.low_lh_bias, self.stride,
+                                    self.padding, self.dilation, self.groups)
+            low_group_upsampled = F.interpolate(low_group_lh, scale_factor=2)
+            high_group_out = high_group_hh + low_group_upsampled
+
+            return high_group_out
+
+        else:
+            high_group_hl = F.conv2d(high_group_pooled, self.high_hl_weight, self.high_hl_bias, self.stride,
+                                     self.padding, self.dilation, self.groups)
+            low_group_lh = F.conv2d(low_group, self.low_lh_weight, self.low_lh_bias, self.stride,
+                                    self.padding, self.dilation, self.groups)
+            low_group_upsampled = F.interpolate(low_group_lh, scale_factor=2)
+            low_group_ll = F.conv2d(low_group, self.low_ll_weight, self.low_ll_bias, self.stride,
+                                    self.padding, self.dilation, self.groups)
+
+            high_group_out = high_group_hh + low_group_upsampled
+            low_group_out = high_group_hl + low_group_ll
+
+            return high_group_out, low_group_out
+
+    @staticmethod
+    def _apply_noise(tensor, mu=0, sigma=0.0001):
+        noise = torch.normal(mean=torch.ones_like(tensor) * mu, std=torch.ones_like(tensor) * sigma)
+
+        return tensor + noise
+
+
+class OctReLU(nn.Module):
+    def __init__(self, inplace=False):
+        super().__init__()
+        self.relu_h, self.relu_l = nn.ReLU(inplace), nn.ReLU(inplace)
+
+    def forward(self, x):
+        h, l = x
+
+        return self.relu_h(h), self.relu_l(l)
+
+
+class OctMaxPool2d(nn.Module):
+    def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False):
+        super().__init__()
+        self.maxpool_h = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation,
+                                      return_indices=return_indices, ceil_mode=ceil_mode)
+        self.maxpool_l = nn.MaxPool2d(kernel_size, stride=stride, padding=padding, dilation=dilation,
+                                      return_indices=return_indices, ceil_mode=ceil_mode)
+
+    def forward(self, x):
+        h, l = x
+
+        return self.maxpool_h(h), self.maxpool_l(l)
+
+
+class Error(Exception):
+    """Base class for all exceptions raised by this module."""
+
+
+class InvalidOctType(Error):
+    """There was a problem in the OctConv type."""
\ No newline at end of file
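Usage sketch (illustrative, not part of the patch): a minimal example of how the new blocks might be wired once the patch is applied and the package is importable as nn_common_modules. The parameter values and import paths below are assumptions, chosen so the hard-coded octave BatchNorm sizes line up (encoder: num_filters == num_channels; decoder: num_filters == num_channels // 2; channel counts divisible by 4; even spatial dimensions), and are not taken from the patch itself.

    import torch
    from nn_common_modules.modules import OctaveEncoderBlock, OctaveDecoderBlock

    # Encoder params (assumed values): num_channels == num_filters keeps the
    # octave BatchNorm channel sizes consistent with the dense concatenations.
    params = {
        'num_channels': 64, 'num_filters': 64,
        'kernel_h': 5, 'kernel_w': 5,
        'stride_conv': 1, 'pool': 2, 'stride_pool': 2,
        'num_classes': 28, 'drop_out': 0.2,
    }

    encoder = OctaveEncoderBlock(params, se_block_type='NONE')
    x = torch.randn(1, 64, 64, 64)            # N x C x H x W, even spatial dims
    out_enc, out_block, indices = encoder(x)  # pooled output, skip tensor, unpool indices

    # The decoder sees the skip connection concatenated with the unpooled tensor,
    # so its num_channels is twice the encoder's num_filters.
    decoder_params = dict(params, num_channels=128)
    decoder = OctaveDecoderBlock(decoder_params, se_block_type='NONE')
    y = decoder(out_enc, out_block=out_block, indices=indices)  # shape (1, 64, 64, 64)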