From 8935496bf4696b240ca83cdbbbbd374a46f91571 Mon Sep 17 00:00:00 2001 From: Alexandre Boulch Date: Wed, 9 Dec 2020 10:26:42 +0000 Subject: [PATCH] networks - add pointnet segmentation --- lightconvpoint/networks/pointnet.py | 33 +++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/lightconvpoint/networks/pointnet.py b/lightconvpoint/networks/pointnet.py index 9568cbd..c76edef 100644 --- a/lightconvpoint/networks/pointnet.py +++ b/lightconvpoint/networks/pointnet.py @@ -6,7 +6,7 @@ class Pointnet(nn.Module): def __init__(self, in_channels, out_channels, - hidden_dim): + hidden_dim, segmentation=False): super().__init__() self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1) @@ -14,7 +14,13 @@ def __init__(self, in_channels, self.fc_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) self.fc_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) self.fc_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1) - self.fc_out = nn.Linear(hidden_dim, out_channels, 1) + + self.segmentation=segmentation + + if segmentation: + self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1) + else: + self.fc_out = nn.Linear(hidden_dim, out_channels) self.activation = nn.ReLU() @@ -36,8 +42,12 @@ def forward(self, x): x = self.fc_3(self.activation(x)) - x = torch.max(x, dim=2)[0] - + if self.segmentation: + x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x) + x = torch.cat([x, x_pool], dim=1) + else: + x = torch.max(x, dim=2)[0] + x = self.fc_out(x) return x @@ -77,7 +87,7 @@ class ResidualPointnet(nn.Module): hidden_dim (int): hidden dimension of the network ''' - def __init__(self, in_channels, out_channels, hidden_dim): + def __init__(self, in_channels, out_channels, hidden_dim, segmentation=False): super().__init__() self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1) @@ -86,7 +96,12 @@ def __init__(self, in_channels, out_channels, hidden_dim): self.block_2 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim) self.block_3 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim) self.block_4 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim) - self.fc_out = nn.Linear(hidden_dim, out_channels) + + self.segmentation = segmentation + if self.segmentation: + self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1) + else: + self.fc_out = nn.Linear(hidden_dim, out_channels) def forward(self, x): @@ -111,7 +126,11 @@ def forward(self, x): x = self.block_4(x) - x = torch.max(x, dim=2)[0] + if self.segmentation: + x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x) + x = torch.cat([x, x_pool], dim=1) + else: + x = torch.max(x, dim=2)[0] x = self.fc_out(x)