Skip to content

Commit

Permalink
networks - add pointnet segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
aboulch committed Dec 9, 2020
1 parent ef3c512 commit 8935496
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions lightconvpoint/networks/pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,21 @@ class Pointnet(nn.Module):

def __init__(self, in_channels,
out_channels,
hidden_dim):
hidden_dim, segmentation=False):
super().__init__()

self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
self.fc_0 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
self.fc_1 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
self.fc_2 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
self.fc_3 = nn.Conv1d(2*hidden_dim, hidden_dim, 1)
self.fc_out = nn.Linear(hidden_dim, out_channels, 1)

self.segmentation=segmentation

if segmentation:
self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1)
else:
self.fc_out = nn.Linear(hidden_dim, out_channels)

self.activation = nn.ReLU()

Expand All @@ -36,8 +42,12 @@ def forward(self, x):

x = self.fc_3(self.activation(x))

x = torch.max(x, dim=2)[0]

if self.segmentation:
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
x = torch.cat([x, x_pool], dim=1)
else:
x = torch.max(x, dim=2)[0]

x = self.fc_out(x)

return x
Expand Down Expand Up @@ -77,7 +87,7 @@ class ResidualPointnet(nn.Module):
hidden_dim (int): hidden dimension of the network
'''

def __init__(self, in_channels, out_channels, hidden_dim):
def __init__(self, in_channels, out_channels, hidden_dim, segmentation=False):
super().__init__()

self.fc_in = nn.Conv1d(in_channels, 2*hidden_dim, 1)
Expand All @@ -86,7 +96,12 @@ def __init__(self, in_channels, out_channels, hidden_dim):
self.block_2 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
self.block_3 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
self.block_4 = ResidualBlock(2*hidden_dim, hidden_dim, hidden_dim)
self.fc_out = nn.Linear(hidden_dim, out_channels)

self.segmentation = segmentation
if self.segmentation:
self.fc_out = nn.Conv1d(2*hidden_dim, out_channels, 1)
else:
self.fc_out = nn.Linear(hidden_dim, out_channels)


def forward(self, x):
Expand All @@ -111,7 +126,11 @@ def forward(self, x):

x = self.block_4(x)

x = torch.max(x, dim=2)[0]
if self.segmentation:
x_pool = torch.max(x, dim=2, keepdim=True)[0].expand_as(x)
x = torch.cat([x, x_pool], dim=1)
else:
x = torch.max(x, dim=2)[0]

x = self.fc_out(x)

Expand Down

0 comments on commit 8935496

Please sign in to comment.