Skip to content

Commit

Permalink
Semantic segmentation NPM3D initial release
Browse files Browse the repository at this point in the history
  • Loading branch information
aboulch committed Aug 19, 2020
1 parent 5f20402 commit 8fbf5c3
Show file tree
Hide file tree
Showing 12 changed files with 874 additions and 14 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ If you use the FKAConv code or the LightConvPoint framework in your research, pl
We provide examples classification and segmentation datasets:
* ModelNet40
* ShapeNet
* S3DIS (*to be released*)
* Semantic8 (*to be released*)
* NPM3D (*to be released*)
* S3DIS
* Semantic8
* NPM3D

22 changes: 15 additions & 7 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,36 @@ python train.py with shapenet.json

The S3DIS is a large indoor dataset for point cloud semantic segmentation.

#### Data preparation
**Data preparation.** We use the data preparation from ConvPoint to create the point cloud text files (https://github.com/aboulch/ConvPoint/tree/master/examples/s3dis).

We use the data preparation from ConvPoint to create the point cloud text files (https://github.com/aboulch/ConvPoint/tree/master/examples/s3dis).
### [Semantic8](http://semantic3d.net/)

#### Training
**Data preparation.** We use the data preparation from ConvPoint to create the point cloud text files (https://github.com/aboulch/ConvPoint/tree/master/examples/semantic3d) with a given voxel size.

**Benchmark file creation.** We use the projection from decimated pointcloud to original ones of ConvPoint (https://github.com/aboulch/ConvPoint/tree/master/examples/semantic3d).

### [NPM3D](https://npm3d.fr/paris-lille-3d)

**Data preparation.** The `prepare_data.py` script splits the training files into smaller files for easy loading.

### Training

To modify the settings, e.g., area, save directory... you can either modify the yaml file, create a copy of the original file with modified arguments or change the options directly in the command line.
```bash
python train.py # will automatically call the modified yaml file
python train.py with new_config_file.yaml # call the default config file and then update parameters with the new one
python train.py with area=X training.savedir="new_savedir_path_areaX" # direct modification in the command line
python train.py with area=X training.savedir="new_savedir_path" # direct modification in the command line
```
The area parameter is the test area identifier, it will train the model on every other areas.
The area parameter (for S3DIS) is the test area identifier, it will train the model on every other areas.

#### Test
### Test

To test the model:
```bash
python test.py -c save_directory_path/config.yaml
```

#### Training and testing with a fusion model
### Training and testing with a fusion model

To train a fusion model, first train independently two models with and without color information (it is an option in the config file).
Then you can train the fusion model by modifying the `s3dis_fusion.yaml` file (mostly set the paths of the two previously trained models) and:
Expand Down
42 changes: 42 additions & 0 deletions examples/npm3d/npm3d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Dataset
dataset:
datasetdir: /datasets_local
dataset: NPM3D_processed/
npoints: 8192
pillar_size: 8

# Network
network:
model: KPConvSeg
backend_conv:
layer: FKAConv
kernel_separation: false
backend_search: SearchQuantized
fusion_submodel: null
fusion_submodeldir: null

# Training
training:
savedir: path_to_save_directory
batchsize: 16
jitter: 0.4
scaling_param: 0
rgb: true
rgb_dropout: false
lr_start: 0.001
epoch_iter: 1000
epoch_nbr: 100
weights: false

# Testing
test:
step: 2
batchsize: 16
savepts: true
savepreds: false

# misc
misc:
device: cuda
disable_tqdm: false
threads: 4
252 changes: 252 additions & 0 deletions examples/npm3d/npm3d_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@

import torch
import numpy as np
import lightconvpoint.nn
import os
import random
from torchvision import transforms
from PIL import Image
from tqdm import *
from plyfile import PlyData, PlyElement
from lightconvpoint.nn import with_indices_computation_rotation

def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1],])
return np.dot(batch_data, rotation_matrix)

# Part dataset only for training / validation
class DatasetTrainVal():

def compute_mask(self, xyzrgb, pt, bs):
# build the mask
mask_x = np.logical_and(xyzrgb[:,0]<=pt[0]+bs/2, xyzrgb[:,0]>=pt[0]-bs/2)
mask_y = np.logical_and(xyzrgb[:,1]<=pt[1]+bs/2, xyzrgb[:,1]>=pt[1]-bs/2)
mask = np.logical_and(mask_x, mask_y)
return mask


def __init__ (self, filelist, folder,
training=False,
block_size=8,
npoints = 8192,
jitter = 0,
iteration_number = None,
rgb_dropout=False,
rgb=True,
in_memory=True,
network_function=None):

self.filelist = filelist
self.folder = folder
self.training = training
self.bs = block_size
self.npoints = npoints
self.iterations = iteration_number
self.rgb_dropout = rgb_dropout
self.rgb=rgb
self.in_memory = in_memory
self.jitter = jitter

self.transform = transforms.ColorJitter(
brightness=jitter,
contrast=jitter,
saturation=jitter)

if network_function is not None:
self.net = network_function()
else:
self.net = None

if self.in_memory:
self.data = []
for filename in filelist:
data = np.load(os.path.join(self.folder, filename))
self.data.append(data)

@with_indices_computation_rotation
def __getitem__(self, index):

# load the data
index = random.randint(0, len(self.filelist)-1)
if self.in_memory:
pts = self.data[index]
else:
pts = np.load(os.path.join(self.folder, self.filelist[index]))

# get the features
fts = np.tile(pts[:,3], [3,1]).transpose()
# get the labels
lbs = pts[:,4].astype(int) # the generation script label starts at 1
# get the point coordinates
pts = pts[:, :3]

# pick a random point
pt_id = random.randint(0, pts.shape[0]-1)
pt = pts[pt_id]

# create the mask
mask = self.compute_mask(pts[:,:2], pt, self.bs)
pts = pts[mask]
lbs = lbs[mask]
fts = fts[mask]

# random selection
if pts.shape[0] < self.npoints:
choice = np.random.choice(pts.shape[0], self.npoints, replace=True)
else:
choice = np.random.choice(pts.shape[0], self.npoints, replace=False)
pts = pts[choice]
lbs = lbs[choice]
fts = fts[choice]

# data augmentation
if self.training:
# random rotation
pts = rotate_point_cloud_z(pts)

# random jittering
fts = fts.astype(np.uint8)
fts = np.array(self.transform( Image.fromarray(np.expand_dims(fts, 0)) ))
fts = np.squeeze(fts, 0)

fts = fts.astype(np.float32)
fts = fts/255 - 0.5

pts = torch.from_numpy(pts).float()
fts = torch.from_numpy(fts).float()
lbs = torch.from_numpy(lbs).long()

if (not self.rgb) or (self.training and self.rgb_dropout and random.randint(0,1)):
fts = torch.ones(fts.shape).float()

pts = pts.transpose(0,1)
fts = fts.transpose(0,1)

return_dict = {
"pts": pts,
"features": fts,
"target": lbs,
}

return return_dict

def __len__(self):
return self.iterations

class DatasetTest():

def compute_mask(self, xyzrgb, pt, bs):
# build the mask
mask_x = np.logical_and(xyzrgb[:,0]<=pt[0]+bs/2, xyzrgb[:,0]>=pt[0]-bs/2)
mask_y = np.logical_and(xyzrgb[:,1]<=pt[1]+bs/2, xyzrgb[:,1]>=pt[1]-bs/2)
mask = np.logical_and(mask_x, mask_y)
return mask

def __init__ (self, filename, folder,
block_size=8,
npoints = 8192, rgb=True,
network_function=None, step=None, offset=0):

self.folder = folder
self.bs = block_size
self.npoints = npoints
self.filename = filename
self.rgb=rgb

step = block_size if step is None else step

if network_function is not None:
self.net = network_function()
else:
self.net = None


# get the data
plydata = PlyData.read(os.path.join(folder, filename))
x = plydata["vertex"].data["x"].astype(np.float32)
y = plydata["vertex"].data["y"].astype(np.float32)
z = plydata["vertex"].data["z"].astype(np.float32)
reflectance = plydata["vertex"].data["reflectance"].astype(np.float32)
self.xyzrgb = np.stack([x,y,z,reflectance], axis=1).astype(np.float32)

mini = self.xyzrgb[:,:2].min(0)
discretized = ((self.xyzrgb[:,:2]-mini+offset).astype(float)/step).astype(int)
self.pts = np.unique(discretized, axis=0)
self.pts = self.pts.astype(np.float)*step + mini - offset + step/2

# compute the masks
self.choices = []
self.pts_ref = []
for index in tqdm(range(self.pts.shape[0]), ncols=80):
pt_ref = self.pts[index]
mask = self.compute_mask(self.xyzrgb, pt_ref, self.bs)

pillar_points_indices = np.where(mask)[0]
valid_points_indices = pillar_points_indices.copy()

while(valid_points_indices is not None):
# print(valid_points_indices.shape[0])
if valid_points_indices.shape[0] > self.npoints:
choice = np.random.choice(valid_points_indices.shape[0], self.npoints, replace=True)
mask[valid_points_indices[choice]] = False
choice = valid_points_indices[choice]
valid_points_indices = np.where(mask)[0]
else:
choice = np.random.choice(pillar_points_indices.shape[0], self.npoints-valid_points_indices.shape[0], replace=True)
choice = np.concatenate([valid_points_indices, pillar_points_indices[choice]], axis=0)
valid_points_indices = None

self.choices.append(choice)
self.pts_ref.append(pt_ref)

@with_indices_computation_rotation
def __getitem__(self, index):

choice = self.choices[index]
pts = self.xyzrgb[choice]
pt_ref = self.pts_ref[index]

# get the features
fts = np.tile(pts[:,3], [3,1]).transpose()
fts = fts.astype(np.float32) / 255 - 0.5
# get the point coordinates
pts = pts[:,:3]


# go numpy
pts = torch.from_numpy(pts).float()
fts = torch.from_numpy(fts).float()
choice = torch.from_numpy(choice).long()

if not self.rgb:
fts = torch.ones(fts.shape).float()

# transpose for light conv point
pts = pts.transpose(0,1)
fts = fts.transpose(0,1)

return_dict = {
"pts": pts,
"features": fts,
"pts_ids": choice
}

return return_dict

def __len__(self):
return len(self.choices)



46 changes: 46 additions & 0 deletions examples/npm3d/npm3d_fusion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Dataset
dataset:
datasetdir: /datasets_local
dataset: NPM3D_processed/
npoints: 8192
pillar_size: 8

# Network
network:
model: Fusion
backend_conv:
layer: FKAConv
kernel_separation: false
backend_search: SearchQuantized
fusion_submodel:
- KPConvSeg
- KPConvSeg
fusion_submodeldir:
- path_to_modeldir_rgb
- path_to_modeldir_noColor

# Training
training:
savedir: path_to_save_directory
batchsize: 16
jitter: 0.4
scaling_param: 0
rgb: true
rgb_dropout: false
lr_start: 0.001
epoch_iter: 1000
epoch_nbr: 20
weights: false

# Testing
test:
step: 2
batchsize: 16
savepts: true
savepreds: false

# misc
misc:
device: cuda
disable_tqdm: false
threads: 4
Loading

0 comments on commit 8fbf5c3

Please sign in to comment.