diff --git a/.gitignore b/.gitignore
index 2df7d0a..d19b3c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+data/
 example_data/
 checkpoint/
 wandb/
diff --git a/README.md b/README.md
index d2ea5c7..8f81115 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,11 @@ You can train, test and visualize the results using our provided config files, o
 We explain our code design in [tutorial.md](docs/tutorial.md). Please read it before modifying the codebase or implementing your new algorithms.
 
+## Benchmark
+
+We benchmark the baselines and report the results in [model.md](docs/model.md).
+We also present detailed instructions on how to reproduce our results.
+
 ## License
 
 This project is released under the [MIT license](LICENSE).
diff --git a/configs/_base_/datasets/breaking_bad/artifact.py b/configs/_base_/datasets/breaking_bad/artifact.py
index 5b867e9..f7f0e35 100644
--- a/configs/_base_/datasets/breaking_bad/artifact.py
+++ b/configs/_base_/datasets/breaking_bad/artifact.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'geometry'
-_C.data_dir = '/scratch/ssd004/scratch/ziyiwu/data/assembly'
+_C.data_dir = './data/breaking_bad'
 _C.data_fn = 'artifact.{}.txt'
 _C.data_keys = ('part_ids', )
 _C.category = '' # empty means all categories
diff --git a/configs/_base_/datasets/breaking_bad/everyday.py b/configs/_base_/datasets/breaking_bad/everyday.py
index 38a051f..45ab3b2 100644
--- a/configs/_base_/datasets/breaking_bad/everyday.py
+++ b/configs/_base_/datasets/breaking_bad/everyday.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'geometry'
-_C.data_dir = '/scratch/ssd004/scratch/ziyiwu/data/assembly'
+_C.data_dir = './data/breaking_bad'
 _C.data_fn = 'everyday.{}.txt'
 _C.data_keys = ('part_ids', )
 _C.category = '' # empty means all categories
diff --git a/configs/_base_/datasets/breaking_bad/other.py b/configs/_base_/datasets/breaking_bad/other.py
index 8b5bdae..a35451a 100644
--- a/configs/_base_/datasets/breaking_bad/other.py
+++ b/configs/_base_/datasets/breaking_bad/other.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'geometry'
-_C.data_dir = '/scratch/ssd004/scratch/ziyiwu/data/assembly'
+_C.data_dir = './data/breaking_bad'
 _C.data_fn = 'other.{}.txt'
 _C.data_keys = ('part_ids', )
 _C.category = '' # empty means all categories
diff --git a/configs/_base_/datasets/partnet/partnet_chair.py b/configs/_base_/datasets/partnet/partnet_chair.py
index b0378f6..2b0a485 100644
--- a/configs/_base_/datasets/partnet/partnet_chair.py
+++ b/configs/_base_/datasets/partnet/partnet_chair.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'partnet'
-_C.data_dir = '../Generative-3D-Part-Assembly/prepare_data'
+_C.data_dir = './data/partnet'
 _C.data_fn = 'Chair.{}.npy'
 _C.category = 'Chair' # actually useless
 _C.data_keys = ('part_ids', 'match_ids', 'contact_points')
diff --git a/configs/_base_/datasets/partnet/partnet_lamp.py b/configs/_base_/datasets/partnet/partnet_lamp.py
index 1c42208..4890b87 100644
--- a/configs/_base_/datasets/partnet/partnet_lamp.py
+++ b/configs/_base_/datasets/partnet/partnet_lamp.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'partnet'
-_C.data_dir = '../Generative-3D-Part-Assembly/prepare_data'
+_C.data_dir = './data/partnet'
 _C.data_fn = 'Lamp.{}.npy'
 _C.category = 'Lamp' # actually useless
 _C.data_keys = ('part_ids', 'match_ids', 'contact_points')
diff --git a/configs/_base_/datasets/partnet/partnet_table.py b/configs/_base_/datasets/partnet/partnet_table.py
index 3725585..760810a 100644
--- a/configs/_base_/datasets/partnet/partnet_table.py
+++ b/configs/_base_/datasets/partnet/partnet_table.py
@@ -4,7 +4,7 @@
 _C = CN()
 
 _C.dataset = 'partnet'
-_C.data_dir = '../Generative-3D-Part-Assembly/prepare_data'
+_C.data_dir = './data/partnet'
 _C.data_fn = 'Table.{}.npy'
 _C.category = 'Table' # actually useless
 _C.data_keys = ('part_ids', 'match_ids', 'contact_points')
diff --git a/configs/_base_/models/loss/geometric_loss.py b/configs/_base_/models/loss/geometric_loss.py
index 9a382a8..6c8a36e 100644
--- a/configs/_base_/models/loss/geometric_loss.py
+++ b/configs/_base_/models/loss/geometric_loss.py
@@ -15,7 +15,6 @@
 # also note that there is almost no symmetry in this dataset
 _C = CN()
 _C.noise_dim = 0 # no stochastic
-_C.num_rot = 1 # rotate GT to match the predictions
 
 _C.trans_loss_w = 1.
 _C.rot_pt_cd_loss_w = 10.
diff --git a/configs/_base_/models/pn_transformer/pn_transformer_gan.py b/configs/_base_/models/pn_transformer/pn_transformer_gan.py
deleted file mode 100644
index 355bc53..0000000
--- a/configs/_base_/models/pn_transformer/pn_transformer_gan.py
+++ /dev/null
@@ -1,23 +0,0 @@
-"""PointNet-Transformer model with adversarial loss."""
-
-from yacs.config import CfgNode as CN
-
-_C = CN()
-_C.name = 'pn_transformer_gan'
-_C.rot_type = 'quat'
-_C.pc_feat_dim = 256
-
-_C.encoder = 'pointnet' # 'dgcnn', 'pointnet2_ssg', 'pointnet2_msg'
-
-_C.transformer_feat_dim = 1024
-_C.transformer_heads = 8
-_C.transformer_layers = 4
-_C.transformer_pre_ln = True
-
-_C.discriminator = 'pointnet' # encoder used in the shape discriminator
-_C.discriminator_num_points = 1024
-_C.discriminator_loss = 'mse' # 'ce'
-
-
-def get_cfg_defaults():
-    return _C.clone()
diff --git a/configs/_base_/models/pn_transformer/vn_pn_transformer.py b/configs/_base_/models/pn_transformer/vn_pn_transformer.py
deleted file mode 100644
index c74cd86..0000000
--- a/configs/_base_/models/pn_transformer/vn_pn_transformer.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""VN-PointNet-Transformer model."""
-
-from yacs.config import CfgNode as CN
-
-_C = CN()
-_C.name = 'vn_pn_transformer'
-_C.rot_type = 'rmat'
-_C.pc_feat_dim = 48 # use a smaller one because VN feature is 3xC
-
-_C.encoder = 'vn-pointnet'
-
-_C.transformer_heads = 4
-_C.transformer_layers = 2
-
-
-def get_cfg_defaults():
-    return _C.clone()
diff --git a/configs/_base_/models/pn_transformer/vn_pn_transformer_v2.py b/configs/_base_/models/pn_transformer/vn_pn_transformer_v2.py
deleted file mode 100644
index 72ef875..0000000
--- a/configs/_base_/models/pn_transformer/vn_pn_transformer_v2.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""VN-PointNet-Transformer model."""
-
-from yacs.config import CfgNode as CN
-
-_C = CN()
-_C.name = 'vn_pn_transformer_v2'
-_C.rot_type = 'rmat'
-_C.pc_feat_dim = 48 # use a smaller one because VN feature is 3xC
-_C.model.rmat_can = False # use a rmat to canonicalize the part feature
-
-_C.encoder = 'vn-pointnet'
-
-_C.transformer_heads = 4
-_C.transformer_layers = 2
-
-
-def get_cfg_defaults():
-    return _C.clone()
diff --git a/configs/dgl/dgl-32x1-cosine_200e-artifact.py b/configs/dgl/dgl-32x1-cosine_200e-artifact.py
index cc6f719..cd52424 100644
--- a/configs/dgl/dgl-32x1-cosine_200e-artifact.py
+++ b/configs/dgl/dgl-32x1-cosine_200e-artifact.py
@@ -20,7 +20,7 @@
 _C.exp.val_every = 5 # DGL training is very slow
 
 _C.data = CN()
-_C.data.data_keys = ('part_ids', 'instance_label', 'valid_matrix')
+_C.data.data_keys = ('part_ids', 'valid_matrix')
 
 
 def get_cfg_defaults():
diff --git a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-artifact.py b/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-artifact.py
deleted file mode 100644
index 7af64f0..0000000
--- a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-artifact.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/breaking_bad/artifact.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/pn_transformer_gan.py',
-    'loss': '../../_base_/models/loss/geometric_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-
-_C.optimizer = CN()
-_C.optimizer.d_lr = 1e-3
-_C.optimizer.warmup_ratio = 0.05
-
-_C.loss = CN()
-_C.loss.g_loss_w = 1.
-_C.loss.d_loss_w = 1.
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-everyday.py b/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-everyday.py
deleted file mode 100644
index 5b34aa2..0000000
--- a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-everyday.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/breaking_bad/everyday.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/pn_transformer_gan.py',
-    'loss': '../../_base_/models/loss/geometric_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-
-_C.optimizer = CN()
-_C.optimizer.d_lr = 1e-3
-_C.optimizer.warmup_ratio = 0.05
-
-_C.loss = CN()
-_C.loss.g_loss_w = 1.
-_C.loss.d_loss_w = 1.
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-partnet_chair.py b/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-partnet_chair.py
deleted file mode 100644
index acce70e..0000000
--- a/configs/pn_transformer/pn_transformer_gan/pn_transformer_gan-32x1-cosine_400e-partnet_chair.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/partnet/partnet_chair.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/pn_transformer_gan.py',
-    'loss': '../../_base_/models/loss/semantic_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-
-_C.optimizer = CN()
-_C.optimizer.d_lr = 1e-3
-_C.optimizer.warmup_ratio = 0.05
-
-_C.loss = CN()
-_C.loss.g_loss_w = 1.
-_C.loss.d_loss_w = 1.
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-artifact.py b/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-artifact.py
deleted file mode 100644
index 0a03c5f..0000000
--- a/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-artifact.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/breaking_bad/artifact.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/vn_pn_transformer.py',
-    'loss': '../../_base_/models/loss/geometric_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-_C.exp.batch_size = 6 # GPU memory limit on RTX6000 with 24GB memory
-_C.exp.num_workers = 6
-
-_C.optimizer = CN()
-_C.optimizer.warmup_ratio = 0.05
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-everyday.py b/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-everyday.py
deleted file mode 100644
index 9145313..0000000
--- a/configs/pn_transformer/vn_pn_transformer/vn_pn_transformer-6x1-cosine_400e-everyday.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/breaking_bad/everyday.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/vn_pn_transformer.py',
-    'loss': '../../_base_/models/loss/geometric_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-_C.exp.batch_size = 6 # GPU memory limit on RTX6000 with 24GB memory
-_C.exp.num_workers = 6
-
-_C.optimizer = CN()
-_C.optimizer.warmup_ratio = 0.05
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/pn_transformer/vn_pn_transformer_v2/vn_pn_transformer_v2-6x1-cosine_400e-everyday.py b/configs/pn_transformer/vn_pn_transformer_v2/vn_pn_transformer_v2-6x1-cosine_400e-everyday.py
deleted file mode 100644
index b9d61cf..0000000
--- a/configs/pn_transformer/vn_pn_transformer_v2/vn_pn_transformer_v2-6x1-cosine_400e-everyday.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import os
-from yacs.config import CfgNode as CN
-from multi_part_assembly.utils import merge_cfg
-
-_base_ = {
-    'exp': '../../_base_/default_exp.py',
-    'data': '../../_base_/datasets/breaking_bad/everyday.py',
-    'optimizer': '../../_base_/schedules/adam_cosine.py',
-    'model': '../../_base_/models/pn_transformer/vn_pn_transformer_v2.py',
-    'loss': '../../_base_/models/loss/geometric_loss.py',
-}
-
-# Miscellaneous configs
-_C = CN()
-
-_C.exp = CN()
-_C.exp.num_epochs = 400
-_C.exp.batch_size = 4 # GPU memory limit on RTX6000 with 24GB memory
-_C.exp.num_workers = 4
-
-_C.optimizer = CN()
-_C.optimizer.warmup_ratio = 0.05
-
-
-def get_cfg_defaults():
-    base_cfg = _C.clone()
-    cfg = merge_cfg(base_cfg, os.path.dirname(__file__), _base_)
-    return cfg
diff --git a/configs/rgl_net/rgl_net-32x1-cosine_200e-artifact.py b/configs/rgl_net/rgl_net-32x1-cosine_200e-artifact.py
index 0f58230..e7bf093 100644
--- a/configs/rgl_net/rgl_net-32x1-cosine_200e-artifact.py
+++ b/configs/rgl_net/rgl_net-32x1-cosine_200e-artifact.py
@@ -17,7 +17,7 @@
 _C.exp.val_every = 5 # to be the same as DGL
 
 _C.data = CN()
-_C.data.data_keys = ('part_ids', 'instance_label', 'valid_matrix')
+_C.data.data_keys = ('part_ids', 'valid_matrix')
 
 
 def get_cfg_defaults():
diff --git a/docs/install.md b/docs/install.md
index 200be14..1bd8b0a 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -11,6 +11,8 @@ conda create -n assembly python=3.8
 conda activate assembly
 # pytorch
 conda install pytorch=1.10 torchvision torchaudio cudatoolkit=11.3 -c pytorch
+# pytorch-lightning
+conda install pytorch-lightning=1.6.2
 # pytorch3d
 conda install -c fvcore -c iopath -c conda-forge fvcore iopath
 conda install pytorch3d -c pytorch3d
diff --git a/docs/model.md b/docs/model.md
new file mode 100644
index 0000000..f99c5e0
--- /dev/null
+++ b/docs/model.md
@@ -0,0 +1,119 @@
+# Supported Models
+
+All assembly models follow a similar pipeline:
+
+1. A point cloud encoder extracts features from each input part.
+   We support common encoders such as PointNet, PointNet++ and DGCNN.
+2. A correlation module performs relation reasoning between part features, which can be an LSTM, a GNN, or a Transformer.
+3. An MLP-based `PoseRegressor` predicts a rotation and a translation for each part.
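+
+Below is a minimal sketch of this three-stage pipeline (the module choices and names are illustrative stand-ins, not the exact implementation in this repo; padding masks are omitted for brevity):
+
+```
+import torch
+import torch.nn as nn
+
+
+class ToyAssemblyModel(nn.Module):
+    """Illustrative encoder -> correlation module -> pose regressor."""
+
+    def __init__(self, feat_dim=256):
+        super().__init__()
+        # 1. PointNet-style stand-in: per-point MLP followed by max-pooling
+        self.encoder = nn.Sequential(
+            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, feat_dim))
+        # 2. correlation module, here a TransformerEncoder
+        layer = nn.TransformerEncoderLayer(feat_dim, nhead=8, batch_first=True)
+        self.correlator = nn.TransformerEncoder(layer, num_layers=2)
+        # 3. MLP pose regressor: quaternion (4) + translation (3) per part
+        self.regressor = nn.Sequential(
+            nn.Linear(feat_dim, 128), nn.ReLU(), nn.Linear(128, 7))
+
+    def forward(self, part_pcs):
+        """part_pcs: [B, P, N, 3] -> quat [B, P, 4], trans [B, P, 3]."""
+        feats = self.encoder(part_pcs).max(dim=2).values  # per-part feature
+        feats = self.correlator(feats)  # interactions across the P parts
+        pose = self.regressor(feats)  # [B, P, 7]
+        quat = nn.functional.normalize(pose[..., :4], dim=-1)
+        return quat, pose[..., 4:]
+```
+
+The models below mostly differ in how step 2 is instantiated.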
+
+## Model Details
+
+Below we briefly describe the methods implemented in this repo:
+
+### Global
+
+A naive baseline from [DGL](https://arxiv.org/pdf/2006.07793.pdf).
+This model concatenates all part point clouds and extracts a _global feature_.
+Then, it concatenates the global feature with each part feature as the input to the pose regressor.
+
+### LSTM
+
+A naive baseline from [DGL](https://arxiv.org/pdf/2006.07793.pdf).
+This model applies a Bidirectional-LSTM over part features for reasoning, and uses the LSTM output for pose prediction.
+
+Note that in PartNet, the order of parts in the pre-processed data follows certain patterns (e.g. from chair leg to seat to back), which causes information leakage when using an LSTM.
+Therefore, we need to shuffle the order of parts in the training data.
+
+### DGL (NeurIPS'20)
+
+Proposed in [DGL](https://arxiv.org/pdf/2006.07793.pdf).
+This model leverages a GNN to perform message passing and interactions between parts.
+Besides, it adopts an iterative refinement process.
+The model first outputs a rough prediction given the initial input parts.
+Then, it applies the predicted transformation to each part, and runs the model on the transformed parts to predict a _residual_ transformation.
+DGL repeats this process several times (3 by default), thus refining the prediction to get a good result.
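+
+The residual refinement loop can be sketched as follows (a simplified illustration; it assumes `model` returns per-part rotation matrices `rmat` [B, P, 3, 3] and translations `trans` [B, P, 3], which is not the exact interface of the DGL implementation in this repo):
+
+```
+import torch
+
+
+def iterative_refine(model, part_pcs, num_iters=3):
+    """Accumulate residual poses over `num_iters` refinement steps."""
+    B, P = part_pcs.shape[:2]
+    total_rmat = torch.eye(3).expand(B, P, 3, 3).contiguous()
+    total_trans = torch.zeros(B, P, 3)
+    for _ in range(num_iters):
+        # predict a residual transformation w.r.t. the current parts
+        rmat, trans = model(part_pcs)
+        # apply it to the (row-vector) part point clouds
+        part_pcs = part_pcs @ rmat.transpose(-1, -2) + trans.unsqueeze(2)
+        # compose it with the pose accumulated so far
+        total_rmat = rmat @ total_rmat
+        total_trans = (total_trans.unsqueeze(2)
+                       @ rmat.transpose(-1, -2)).squeeze(2) + trans
+    return total_rmat, total_trans
+```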
+
+### RGL-NET (WACV'22)
+
+Proposed in [RGL-NET](https://arxiv.org/pdf/2107.12859.pdf).
+Intuitively, RGL-NET is a combination of DGL and LSTM.
+It applies both a GNN and a Bidirectional-LSTM to reason about part relations.
+It assumes the input parts follow some order (this does not necessarily require GT part labels; the order can also be based on e.g. part volumes, see Table 4 in their paper).
+
+We do not implement the input sorting operation.
+This is because, on the one hand, the pre-processed PartNet data does follow certain patterns, and we observe that RGL-NET can indeed leverage such patterns.
+On the other hand, in geometric assembly there is no semantically meaningful order of parts.
+Indeed, in our Breaking Bad benchmark, RGL-NET performs similarly to DGL, so we do not include it in our paper.
+
+### Transformer-based Methods (our design)
+
+This class of methods simply replaces the GNN with a standard TransformerEncoder to learn part interactions.
+We also provide a variant that adopts the iterative refinement process as in DGL.
+
+**Remark**: We implement two additional Transformer-based models, which are further discussed in the `dev` branch.
+
+## Benchmarks
+
+### Semantic Assembly
+
+- Results on PartNet chair:
+
+| Method | Shape Chamfer (SCD) ↓ | Part Accuracy (%) ↑ | Connectivity Accuracy (%) ↑ |
+| :---: | :---: | :---: | :---: |
+| [Global](../configs/global/global-32x1-cosine_200e-partnet_chair.py) | 0.0128 | 23.82 | 16.29 |
+| [LSTM](../configs/lstm/lstm-32x1-cosine_200e-partnet_chair.py) | 0.0114 | 22.03 | 14.88 |
+| [DGL](../configs/dgl/dgl-32x1-cosine_300e-partnet_chair.py) | 0.0079 | 40.56 | 27.58 |
+| [RGL-NET](../configs/rgl_net/rgl_net-32x1-cosine_300e-partnet_chair.py) | 0.0068 | 44.24 | 29.38 |
+| [Transformer](../configs/pn_transformer/pn_transformer/pn_transformer-32x1-cosine_400e-partnet_chair.py) | 0.0089 | 41.90 | 29.11 |
+| [Refine-Transformer](../configs/pn_transformer/pn_transformer_refine/pn_transformer_refine-32x1-cosine_400e-partnet_chair.py) | 0.0079 | 42.97 | 31.25 |
+
+See the [wandb report](https://wandb.ai/dazitu616/Multi-Part-Assembly/reports/Benchmark-on-PartNet-Chair-Assembly--VmlldzoyNzI0NTg5?accessToken=zhov8augcax9ud8rvwemv3k9n120i2hvnjiskms6o2nx1esd3xkz8o18l55ugxhv) for detailed training logs.
+
+To reproduce the results, taking DGL as an example, simply run:
+
+```
+GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=4 QOS=normal REPEAT=1 ./scripts/dup_run_sbatch_ddl.sh $PARTITION dgl-32x1-cosine_300e-partnet_chair scripts/train.py configs/dgl/dgl-32x1-cosine_300e-partnet_chair.py --fp16 --cudnn
+```
+
+Then, you can go to wandb to find the results.
+
+### Geometric Assembly
+
+- Results on the Breaking Bad dataset: see our [paper](https://openreview.net/forum?id=mJWt6pOcHNy)
+
+**To reproduce our main results on the everyday subset (paper Table 3)**, taking DGL as an example, please run:
+
+```
+./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=4 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh $PARTITION dgl-32x1-cosine_200e-everyday-CATEGORY ./scripts/train.py configs/dgl/dgl-32x1-cosine_200e-everyday.py --fp16 --cudnn" configs/dgl/dgl-32x1-cosine_200e-everyday.py
+```
+
+- This assumes you are working on a slurm-based computing cluster.
+  If you work on plain servers, you will need to manually train the model on each category.
+- In Table 3, we train one model per category, and report the numbers averaged over all categories.
+- Since some categories have only a few base shapes, the results may vary among different runs.
+  Therefore, we run all the experiments 3 times and report the average results.
+  You can modify the `REPEAT=3` flag above to your needs.
+
+After running the above script, the model weights will be saved in `checkpoint/dgl-32x1-cosine_200e-everyday-$CATEGORY-dup$X`, where `$CATEGORY` is the category (e.g. Bottle, Teapot), and `X` indexes different runs.
+
+To collect the results, run (assuming you are in a GPU environment):
+
+```
+python scripts/collect_test.py --cfg_file configs/dgl/dgl-32x1-cosine_200e-everyday.py --num_dup 3 --ckp_suffix checkpoint/dgl-32x1-cosine_200e-everyday-
+```
+
+It will automatically test each model, collect its evaluation metrics, aggregate them, and format the results into a **LaTeX**-style table, which you can directly copy-paste into your paper.
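+
+Conceptually, the aggregation amounts to something like the sketch below (the nested-dict data layout is an assumption for illustration, not the script's actual internals):
+
+```
+import numpy as np
+
+
+def to_latex_row(results, metric):
+    """`results[category][metric]` holds one value per duplicate run."""
+    cats = sorted(results)
+    # average each category's metric over the `--num_dup` runs
+    means = [np.mean(results[cat][metric]) for cat in cats]
+    cells = ' & '.join(f'{m:.3f}' for m in means)
+    return f'{metric} & {cells} \\\\'  # one LaTeX table row
+```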
+
+**To reproduce our ablation study results (paper Table 4)**, you need to create a new config file for each model, and set `_C.data.max_num_part` to the number you want to try.
+Then, you can train the models in the same way as detailed above.
+
+To collect the results, again you can use the `scripts/collect_test.py` script.
+To control the number of pieces to test, you can set the `--min_num_part` and `--max_num_part` flags.
+
+**To reproduce our results in the appendix (Table 11 bottom)**, i.e. training one model on all the categories, simply run:
+
+```
+GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=4 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh $PARTITION dgl-32x1-cosine_200e-everyday scripts/train.py configs/dgl/dgl-32x1-cosine_200e-everyday.py --fp16 --cudnn
+```
+
+Then, you can use the same script to collect the results as detailed above (add a `--train_all` flag because the model is trained on all categories jointly).
diff --git a/docs/tutorial.md b/docs/tutorial.md
index 960b875..acff582 100644
--- a/docs/tutorial.md
+++ b/docs/tutorial.md
@@ -22,11 +22,14 @@ For other data items, see comments in the [dataset files](../multi_part_assembly
 
 ## Model
 
 Shape assembly models usually consist of a point cloud feature extractor (e.g. PointNet), a relationship reasoning module (e.g. GNNs), and a pose predictor (usually implemented as MLPs).
+See [model.md](./model.md) for details about the baselines supported in this codebase.
 
 ### Base Model
 
 We implement a `BaseModel` class as an instance of PyTorch-Lightning's `LightningModule`, which supports general methods such as `training/validation/test_step/epoch_end()`.
 It also implements general loss computation, metrics calculation, and visualization during training.
+See [base_model.py](../multi_part_assembly/models/modules/base_model.py).
+Below we detail some core methods we implement for all assembly models.
 
 ### Assembly Models
 
@@ -68,9 +71,10 @@ See `_match_parts()` method of `BaseModel` class.
 
 ### Geometric Assembly
 
 Usually, there are no geometrically equivalent parts in this setting.
-However, sometimes it is hard to define a canonical pose for objects due to symmetry.
-Therefore, we develop a rotation matching step to minimize the loss.
-We rotate the ground-truth object along Z-axis for different angles and select one as the new ground-truth.
+So we do not need to perform the GT matching step.
+
+**Remark**: It is actually very hard to define a _canonical_ pose for objects under the geometric assembly setting, due to e.g. the symmetry of a bottle/vase.
+See the `dev` branch for our experimental features addressing this issue.
 
 ## Metrics
 
@@ -80,15 +84,8 @@ Please refer to Section 4.3 of the [paper](https://arxiv.org/pdf/2006.07793.pdf)
 For geometric assembly, we adopt SCD and PA, as well as MSE/RMSE/MAE between translations and rotations.
 Please refer to Section 6.1 of the [paper](https://arxiv.org/pdf/2205.14886.pdf) for more details.
 
-**Experimental**:
-
-- As pointed out by some papers (e.g. [this](https://www.cs.cmu.edu/~cga/dynopt/readings/Rmetric.pdf)), MSE between rotations is not a good metric.
-  Therefore, we adopt the geodesic distance between two rotations as another metric.
-  See `rot_geodesic_dist()` function in [eval_utils.py](../multi_part_assembly/utils/eval_utils.py).
-- For objects without a clear canonical pose, we compute the relative pose errors.
-  Suppose there are 10 parts in a shape.
-  Every time we treat 1 part as canonical, and calculate the relative poses between the other 9 parts to it.
-  We repeat this process for 10 times, and take the min error as the final result.
+**Remark**: As discussed above, these metrics are sometimes problematic due to the symmetry ambiguity.
+See the `dev` branch for experimental metrics that are robust under this setting.
 
 ## Rotation Representation
 
@@ -97,8 +94,9 @@
 - For ease of data batching, we always represent rotations as quaternions from the dataloaders.
   However, to build a compatible interface for util functions and model input/output, we wrap the predicted rotations in a `Rotation3D` class, which supports common format conversions and tensor operations.
   See [rotation.py](../multi_part_assembly/utils/rotation.py) for detailed definitions.
-- Other rotation representation we support:
-  - 6D representation (rotation matrix): see CVPR'19 [paper](https://zhouyisjtu.github.io/project_rotation/rotation.html).
+- Rotation representations we support (set `_C.rot_type` under the `model` field to switch between them):
+  - Quaternion (`quat`), the default
+  - 6D representation (rotation matrix, `rmat`): see CVPR'19 [paper](https://zhouyisjtu.github.io/project_rotation/rotation.html).
     The predicted `6`-len tensor will be reshaped to `(2, 3)`, and the third row is obtained via cross product.
     Then, the 3 vectors will be stacked along the `-2`-th dim.
     In a `Rotation3D` object, the 6D representation will be converted to a 3x3 rotation matrix.
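+
+For intuition, this 6D-to-rotation-matrix conversion is essentially the following Gram-Schmidt procedure (a minimal sketch; the conversion actually used by the codebase lives in the `Rotation3D` class):
+
+```
+import torch
+import torch.nn.functional as F
+
+
+def rot6d_to_rmat(rot6d):
+    """[..., 6] -> [..., 3, 3] rotation matrix."""
+    a1, a2 = rot6d[..., :3], rot6d[..., 3:]
+    b1 = F.normalize(a1, dim=-1)
+    # make the second row orthogonal to the first, then normalize it
+    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
+    b3 = torch.cross(b1, b2, dim=-1)  # third row via cross product
+    return torch.stack([b1, b2, b3], dim=-2)  # stack along the -2 dim
+```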
diff --git a/docs/usage.md b/docs/usage.md
index 729b373..958096d 100644
--- a/docs/usage.md
+++ b/docs/usage.md
@@ -2,6 +2,10 @@
 
 ## Training
 
+**Please go through the `scripts/train.py` file and modify the training configurations according to the items marked by `TODO`**, e.g. disabling DDP training or handling cluster/slurm-related settings.
+
+See [model.md](./model.md) for details about the baselines supported in this codebase.
+
 To train a model, simply run:
 
 ```
@@ -12,7 +16,7 @@ For example, to train the Global baseline model on PartNet chair, replace `$CFG`
 Other optional arguments include:
 
-- `--gpus`: setting training GPUs, note that by default we are using DP training. Please modify `scripts/train.py` to enable DDP training
+- `--gpus`: setting training GPUs; note that by default we use DDP training as supported by PyTorch-Lightning
 - `--weight`: loading pre-trained weights
 - `--fp16`: FP16 mixed precision training
 - `--cudnn`: setting `cudnn.benchmark = True`
 
@@ -21,20 +25,20 @@
 ### Logging
 
 We use [wandb](https://wandb.ai/site) for logging.
-Please set up your account on the machine before running training commands.
+Please set up your account with `wandb login` on the machine before running training commands.
 
 ### Helper Scripts
 
 Script for configuring and submitting jobs to cluster SLURM system:
 
 ```
-GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal ./scripts/sbatch_run.sh $PARTITION $JOB_NAME ./scripts/train.py --cfg_file $CFG --other_args...
+GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal ./scripts/sbatch_run.sh $PARTITION $JOB_NAME ./scripts/train.py --cfg_file $CFG --other_args...
 ```
 
 Script for running a job multiple times over different random seeds:
 
 ```
-GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=$NUM_REPEAT ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME ./scripts/train.py $CFG --other_args...
+GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=$NUM_REPEAT ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME ./scripts/train.py $CFG --other_args...
 ```
 
 We also provide scripts for training on single/all categories of the Breaking-Bad dataset's `everyday` subset.
@@ -70,7 +74,7 @@ python scripts/collect_test.py --cfg_file $CFG.py --num_dup $X --ckp_suffix chec
 
 The per-category results will be formatted into LaTeX table style for ease of paper writing.
 
-Besides, if you train the models on all categories by running `GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=$NUM_REPEAT ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME ./scripts/train.py $CFG --other_args...`. Then the model checkpoint will be saved in `checkpoint/$CFG-dup$X`. To collect the performance, simply adding a `--train_all` flag:
+Besides, if you train the models on all categories by running `GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=$NUM_REPEAT ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME ./scripts/train.py $CFG --other_args...`, the model checkpoints will be saved in `checkpoint/$CFG-dup$X`. To collect the performance, simply add a `--train_all` flag:
 
 ```
 python scripts/collect_test.py --cfg_file $CFG.py --num_dup $X --ckp_suffix checkpoint/$CFG- --train_all
diff --git a/multi_part_assembly/models/__init__.py b/multi_part_assembly/models/__init__.py
index 9c54313..f931fd0 100644
--- a/multi_part_assembly/models/__init__.py
+++ b/multi_part_assembly/models/__init__.py
@@ -1,6 +1,5 @@
 from .modules import *
-from .pn_transformer import PNTransformer, PNTransformerGAN, \
-    PNTransformerRefine, VNPNTransformer, VNPNTransformerV2
+from .pn_transformer import PNTransformer, PNTransformerRefine
 from .b_identity import IdentityModel
 from .b_global import GlobalModel
 from .b_lstm import LSTMModel
@@ -21,13 +20,7 @@
         return RGLNet(cfg)
     elif cfg.model.name == 'pn_transformer':
         return PNTransformer(cfg)
-    elif cfg.model.name == 'pn_transformer_gan':
-        return PNTransformerGAN(cfg)
     elif cfg.model.name == 'pn_transformer_refine':
         return PNTransformerRefine(cfg)
-    elif cfg.model.name == 'vn_pn_transformer':
-        return VNPNTransformer(cfg)
-    elif cfg.model.name == 'vn_pn_transformer_v2':
-        return VNPNTransformerV2(cfg)
     else:
         raise NotImplementedError(f'Model {cfg.model.name} not supported')
diff --git a/multi_part_assembly/models/modules/__init__.py b/multi_part_assembly/models/modules/__init__.py
index dde8312..fcc3740 100644
--- a/multi_part_assembly/models/modules/__init__.py
+++ b/multi_part_assembly/models/modules/__init__.py
@@ -1,5 +1,4 @@
 from .rnn import RNNWrapper
-from .vnn import *
 from .encoder import *
-from .regressor import PoseRegressor, StocasticPoseRegressor, VNPoseRegressor
+from .regressor import PoseRegressor, StocasticPoseRegressor
 from .base_model import BaseModel
diff --git a/multi_part_assembly/models/modules/base_model.py b/multi_part_assembly/models/modules/base_model.py
index 723207c..5d34c8f 100644
--- a/multi_part_assembly/models/modules/base_model.py
+++ b/multi_part_assembly/models/modules/base_model.py
@@ -4,14 +4,13 @@
 import numpy as np
 from scipy.optimize import
linear_sum_assignment -from scipy.spatial.transform import Rotation as R -from multi_part_assembly.utils import rot_pc, transform_pc, Rotation3D +from multi_part_assembly.utils import transform_pc, Rotation3D from multi_part_assembly.utils import colorize_part_pc, filter_wd_parameters from multi_part_assembly.utils import trans_l2_loss, rot_points_cd_loss, \ shape_cd_loss, rot_cosine_loss, rot_points_l2_loss, chamfer_distance from multi_part_assembly.utils import calc_part_acc, calc_connectivity_acc, \ - trans_metrics, rot_metrics, rot_geodesic_dist, relative_pose_metrics + trans_metrics, rot_metrics from multi_part_assembly.utils import CosineAnnealingWarmupRestarts @@ -53,7 +52,6 @@ def _setup(self): # loss configs self.sample_iter = self.cfg.loss.get('sample_iter', 1) - self.num_rot = self.cfg.loss.get('num_rot', 1) def forward(self, data_dict): """Forward pass to predict poses for each part.""" @@ -136,7 +134,7 @@ def forward_pass(self, data_dict, mode, optimizer_idx): loss_dict = self.loss_function(data_dict, optimizer_idx=optimizer_idx) # in training we log for every step - if mode == 'train': + if mode == 'train' and self.local_rank == 0: log_dict = {f'{mode}/{k}': v.item() for k, v in loss_dict.items()} data_name = [ k for k in self.trainer.profiler.recorded_durations.keys() @@ -144,7 +142,8 @@ def forward_pass(self, data_dict, mode, optimizer_idx): ][0] log_dict[f'{mode}/data_time'] = \ self.trainer.profiler.recorded_durations[data_name][-1] - self.log_dict(log_dict, logger=True, sync_dist=False) + self.log_dict( + log_dict, logger=True, sync_dist=False, rank_zero_only=True) return loss_dict @@ -238,48 +237,6 @@ def _match_parts(self, part_pcs, pred_trans, pred_rot, gt_trans, gt_rot, new_gt_rot = self._wrap_rotation(new_gt_rot_tensor) return new_gt_trans, new_gt_rot - @torch.no_grad() - def _match_rotation(self, pred_trans, pred_rot, gt_trans, gt_rot, valids): - """Used in geometric assembly. Match GT to predictions. - - Since objects in geometric assembly are often symmetric, we rotate the - GT and match the prediction. We use trans MSE as the criterion. - - Args: - pred/gt_trans: [B, P, 3] - pred/gt_rot: [B, P, 4/(3, 3)], Rotation3D, quat or rmat - valids: [B, P], 1 for input parts, 0 for padded parts - - Returns: - GT poses after rearrangement - """ - if self.num_rot == 1: - return gt_trans.detach().clone(), gt_rot.detach().clone() - P = pred_trans.shape[1] - # uniform rotation along z-axis - if not hasattr(self, '_uniform_z_rot'): - z_angles = 360. / self.num_rot * np.arange(self.num_rot) - z_rot = [ - R.from_euler('z', angle, degrees=True).as_matrix() - for angle in z_angles - ] - self._uniform_z_rot = torch.from_numpy(np.stack(z_rot, 0))[None] - z_rot = self._uniform_z_rot.type_as(gt_trans) # [1, n, 3, 3] - # rotate `gt_trans`, [B, n, P, 3] - rot_gt_trans = (z_rot.unsqueeze(2) - @ gt_trans.unsqueeze(1).unsqueeze(-1)).squeeze(-1) - trans_loss = (rot_gt_trans - pred_trans.unsqueeze(1)).pow(2).sum(-1) - valids = valids.unsqueeze(1).float() - trans_loss = (trans_loss * valids).sum(-1) / valids.sum(-1) # [B, n] - # take the min rotation for each data in the batch - min_idx = trans_loss.argmin(1) # [B] - min_z_rot = z_rot[0][min_idx].unsqueeze(1).repeat(1, P, 1, 1) - min_z_rot = Rotation3D(min_z_rot, rot_type='rmat') - # rotate the GTs - new_gt_trans = rot_pc(min_z_rot, gt_trans) - new_gt_rot = gt_rot.apply_rotation(min_z_rot) - return new_gt_trans, new_gt_rot - def _calc_loss(self, out_dict, data_dict): """Calculate loss by matching GT to prediction. 
@@ -295,10 +252,10 @@ def _calc_loss(self, out_dict, data_dict): new_trans, new_rot = self._match_parts(part_pcs, pred_trans, pred_rot, gt_trans, gt_rot, match_ids) - # rotate the object for lowest translation MSE in geometric assembly + # do nothing in geometric assembly else: - new_trans, new_rot = self._match_rotation(pred_trans, pred_rot, - gt_trans, gt_rot, valids) + new_trans, new_rot = \ + gt_trans.detach().clone(), gt_rot.detach().clone() # computing loss trans_loss = trans_l2_loss(pred_trans, new_trans, valids) @@ -312,7 +269,19 @@ def _calc_loss(self, out_dict, data_dict): new_rot, valids, ret_pts=True, - training=self.training, + training=self.semantic or self.training, + # TODO: divide the SCD loss by the real number of parts (False) or + # TODO: a fixed padding number (e.g. 20 in PartNet) (True) + # In semantic assembly, we follow DGL to divide by padding number. + # During training, it serves as hard negative mining; while it's + # also valid during testing because all the shapes have the same + # `max_num_part` value. So we always set `training=True` here. + # In geometric assembly, we do hard negative mining during training + # too, but divide SCD by the real number of parts during testing, + # which is also the results reported in the Breaking Bad paper. + # This is because the number of parts here could vary, e.g. we have + # ablation study on different number of parts (paper Table 4). + # See the docstring of this loss function for more details. ) loss_dict = { 'trans_loss': trans_loss, @@ -367,12 +336,6 @@ def _calc_metrics(self, data_dict, out_dict, gt_trans, gt_rot): pred_trans, gt_trans, valids, metric=metric) metric_dict[f'rot_{metric}'] = rot_metrics( pred_rot, gt_rot, valids, metric=metric) - metric_dict['geo_rot'] = rot_geodesic_dist(pred_rot, gt_rot, - valids) - # relative pose metrics - relative_metric_dict = relative_pose_metrics( - pred_trans, gt_trans, pred_rot, gt_rot, valids) - metric_dict.update(relative_metric_dict) return metric_dict def _loss_function(self, data_dict, out_dict={}, optimizer_idx=-1): diff --git a/multi_part_assembly/models/modules/encoder/__init__.py b/multi_part_assembly/models/modules/encoder/__init__.py index 34f3606..c317ac1 100644 --- a/multi_part_assembly/models/modules/encoder/__init__.py +++ b/multi_part_assembly/models/modules/encoder/__init__.py @@ -1,4 +1,4 @@ -from .pointnet import PointNet, VNPointNet +from .pointnet import PointNet from .dgcnn import DGCNN from .pointnet2 import PointNet2SSG, PointNet2MSG @@ -16,8 +16,6 @@ def build_encoder(arch, feat_dim, global_feat=True, **kwargs): model = PointNet2MSG(feat_dim) else: raise NotImplementedError(f'{arch} not supported') - elif arch == 'vn-pointnet': - model = VNPointNet(feat_dim, global_feat=global_feat, **kwargs) else: raise NotImplementedError(f'{arch} is not supported') return model diff --git a/multi_part_assembly/models/modules/encoder/pointnet.py b/multi_part_assembly/models/modules/encoder/pointnet.py index f6920e5..7cc57cd 100644 --- a/multi_part_assembly/models/modules/encoder/pointnet.py +++ b/multi_part_assembly/models/modules/encoder/pointnet.py @@ -2,9 +2,6 @@ import torch.nn as nn import torch.nn.functional as F -from ..vnn import VNLinear, VNBatchNorm, VNMaxPool, \ - VNLinearBNLeakyReLU - class PointNet(nn.Module): """PointNet feature extractor. 
@@ -42,104 +39,3 @@ def forward(self, x): else: feat = x.transpose(2, 1).contiguous() # [B, N, feat_dim] return feat - - -def knn(x, k): - """x: [B, C, N]""" - inner = -2 * torch.matmul(x.transpose(2, 1), x) - xx = torch.sum(x**2, dim=1, keepdim=True) - pairwise_distance = -xx - inner - xx.transpose(2, 1) - idx = pairwise_distance.topk(k=k, dim=-1)[1] # [B, N, k] - return idx - - -def vn_get_graph_feature(x, k=20, idx=None): - """x: [B, C, 3, N]""" - batch_size = x.size(0) - num_points = x.size(3) - x = x.view(batch_size, -1, num_points) # [B, C*3, N] - if idx is None: - idx = knn(x, k=k) # [B, N, k] - device = x.device - - idx_base = torch.arange( - 0, batch_size, device=device).view(-1, 1, 1) * num_points - - idx = idx + idx_base # [B, N, k] - - idx = idx.view(-1) - - _, num_dims, _ = x.size() - num_dims = num_dims // 3 - - x = x.transpose(2, 1).contiguous() # [B, N, C*3] - feature = x.view(batch_size * num_points, -1)[idx, :] # [B, N, k, C*3] - feature = feature.view(batch_size, num_points, k, num_dims, 3) - x = x.view(batch_size, num_points, 1, num_dims, 3).repeat(1, 1, k, 1, 1) - cross = torch.cross(feature, x, dim=-1) # [B, N, k, C, 3] - - feature = torch.cat((feature - x, x, cross), - dim=3).permute(0, 3, 4, 1, 2).contiguous() - # [B, 3C, 3, N, k] - return feature - - -def mean_pool(x, dim=-1, keepdim=False): - return x.mean(dim=dim, keepdim=keepdim) - - -class VNPointNet(nn.Module): - """VNN-based rotation Equivariant PointNet feature extractor. - - Input point clouds [B, N, 3]. - Output per-point feature [B, N, feat_dim, 3] or - global feature [B, feat_dim, 3]. - """ - - def __init__(self, feat_dim, global_feat=True, **kwargs): - super().__init__() - - self.conv1 = VNLinearBNLeakyReLU(3, 64, dim=5, negative_slope=0.0) - self.conv2 = VNLinearBNLeakyReLU(64, 64, dim=4, negative_slope=0.0) - self.conv3 = VNLinearBNLeakyReLU(64, 64, dim=4, negative_slope=0.0) - self.conv4 = VNLinearBNLeakyReLU(64, 128, dim=4, negative_slope=0.0) - self.conv5 = VNLinear(128, feat_dim, dim=4) - self.bn5 = VNBatchNorm(feat_dim, dim=4) - - pool1 = kwargs.get('pool1', 'mean') # in-knn pooling - self.pool1 = self._build_pooling(pool1, 64) - pool2 = kwargs.get('pool2', 'max') # final global_feats pooling - self.pool2 = self._build_pooling(pool2, feat_dim) - - self.global_feat = global_feat - - @staticmethod - def _build_pooling(pooling, dim=None): - if pooling == 'max': - pool = VNMaxPool(dim) - elif pooling == 'mean': - pool = mean_pool - else: - raise NotImplementedError(f'{pooling}-pooling not implemented') - return pool - - def forward(self, x): - """x: [B, N, 3]""" - x = x.transpose(2, 1).contiguous() # [B, 3, N] - - x = x.unsqueeze(1) # [B, 1, 3, N] - feat = vn_get_graph_feature(x) # [B, 3, 3, N, k] - x = self.conv1(feat) # [B, C, 3, N, k] - x = self.pool1(x) # [B, C, 3, N] - - x = self.conv2(x) - x = self.conv3(x) - x = self.conv4(x) - - x = self.bn5(self.conv5(x)) # [B, feat_dim, 3, N] - - if self.global_feat: - feat = self.pool2(x) # [B, feat_dim, 3] - else: - feat = x.permute(0, 3, 1, 2).contiguous() # [B, N, feat_dim, 3] - return feat diff --git a/multi_part_assembly/models/modules/regressor.py b/multi_part_assembly/models/modules/regressor.py index f8a82e0..9ccc5c8 100644 --- a/multi_part_assembly/models/modules/regressor.py +++ b/multi_part_assembly/models/modules/regressor.py @@ -2,8 +2,6 @@ import torch.nn as nn import torch.nn.functional as F -from .vnn import VNLinear, VNLeakyReLU, VNInFeature - def normalize_rot6d(rot): """Adopted from PyTorch3D. 
@@ -70,64 +68,6 @@ def forward(self, x): return rot, trans -class VNPoseRegressor(nn.Module): - """PoseRegressor for VN models. - - Target rotation should be rotation-equivariant, while target translation - should be rotation-invariant. - """ - - def __init__(self, feat_dim, rot_type='rmat', norm_rot=True): - super().__init__() - - assert rot_type == 'rmat', 'VN model only supports rotation matrix' - self.norm_rot = norm_rot - - # for rotation - self.vn_fc_layers = nn.Sequential( - VNLinear(feat_dim, 256, dim=3), - VNLeakyReLU(256, dim=3, negative_slope=0.2), - VNLinear(256, 128, dim=3), - VNLeakyReLU(128, dim=3, negative_slope=0.2), - ) - - # Rotation prediction head - # we use the 6D representation from the CVPR'19 paper - self.rot_head = VNLinear(128, 2, dim=3) # [2, 3] --> 6 - - # for translation - self.in_feats = VNInFeature(feat_dim, dim=3) - self.fc_layers = nn.Sequential( - nn.Linear(feat_dim * 3, 256), - nn.LeakyReLU(0.2), - nn.Linear(256, 128), - nn.LeakyReLU(0.2), - ) - - # Translation prediction head - self.trans_head = nn.Linear(128, 3) - - def forward(self, x): - """x: [B, C, 3] or [B, P, C, 3]""" - unflatten = len(x.shape) == 4 - B, C = x.shape[0], x.shape[-2] - x = x.view(-1, C, 3) - # rotation - rot_x = self.vn_fc_layers(x) # [N, 128, 3] - rot = self.rot_head(rot_x) # [N, 2, 3] - if self.norm_rot: - rot = normalize_rot6d(rot) # [N, 2, 3] - # translation - trans_x = self.in_feats(x).flatten(-2, -1) # [N, C*3] - trans_x = self.fc_layers(trans_x) # [N, 128] - trans = self.trans_head(trans_x) # [N, 3] - # back to [B, P] - if unflatten: - rot = rot.unflatten(0, (B, -1)) - trans = trans.unflatten(0, (B, -1)) - return rot, trans - - class StocasticPoseRegressor(PoseRegressor): """Stochastic pose regressor with noise injection.""" diff --git a/multi_part_assembly/models/modules/vnn/__init__.py b/multi_part_assembly/models/modules/vnn/__init__.py deleted file mode 100644 index 956a557..0000000 --- a/multi_part_assembly/models/modules/vnn/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .modules import VNLinear, VNBatchNorm, VNLayerNorm, VNReLU, VNLeakyReLU, \ - VNLinearBNLeakyReLU, VNMaxPool, VNInFeature, VNEqFeature -from .transformer import VNTransformerEncoderLayer, VNSelfAttention diff --git a/multi_part_assembly/models/modules/vnn/modules.py b/multi_part_assembly/models/modules/vnn/modules.py deleted file mode 100644 index 82ba8a6..0000000 --- a/multi_part_assembly/models/modules/vnn/modules.py +++ /dev/null @@ -1,316 +0,0 @@ -"""Code borrowed from: https://github.com/FlyingGiraffe/vnn-pc""" - -import torch -import torch.nn as nn - -from pytorch3d.transforms import rotation_6d_to_matrix as rot6d_to_matrix - -EPS = 1e-6 - - -def conv1x1(in_channels, out_channels, dim): - if dim == 3: - return nn.Conv1d(in_channels, out_channels, 1, bias=False) - elif dim == 4: - return nn.Conv2d(in_channels, out_channels, 1, bias=False) - elif dim == 5: - return nn.Conv3d(in_channels, out_channels, 1, bias=False) - else: - raise NotImplementedError(f'{dim}D 1x1 Conv is not supported') - - -class VNLinear(nn.Module): - - def __init__(self, in_channels, out_channels, dim): - super().__init__() - - self.map_to_feat = conv1x1(in_channels, out_channels, dim) - - def forward(self, x): - """ - Args: - x: point features of shape [B, C_in, 3, N, ...] - - Returns: - [B, C_out, 3, N, ...] 
- """ - x_out = self.map_to_feat(x) - return x_out - - -class VNBatchNorm(nn.Module): - - def __init__(self, num_features, dim): - super().__init__() - - if dim == 3 or dim == 4: - self.bn = nn.BatchNorm1d(num_features) - elif dim == 5: - self.bn = nn.BatchNorm2d(num_features) - else: - raise NotImplementedError(f'{dim}D is not supported') - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N, ...] - - Returns: - features of the same shape after BN along C-dim - """ - norm = torch.norm(x, dim=2) + EPS - norm_bn = self.bn(norm) - norm = norm.unsqueeze(2) - norm_bn = norm_bn.unsqueeze(2) - x = x / norm * norm_bn - return x - - -class VNLeakyReLU(nn.Module): - - def __init__( - self, - in_channels, - dim, - share_nonlinearity=False, - negative_slope=0.2, - ): - super().__init__() - - if share_nonlinearity: - self.map_to_dir = conv1x1(in_channels, 1, dim=dim) - else: - self.map_to_dir = conv1x1(in_channels, in_channels, dim=dim) - self.negative_slope = negative_slope - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N, ...] - - Returns: - features of the same shape after LeakyReLU - """ - d = self.map_to_dir(x) - dotprod = (x * d).sum(2, keepdim=True) - mask = (dotprod >= 0).float() - d_norm_sq = d.pow(2).sum(2, keepdim=True) - x_out = self.negative_slope * x + (1 - self.negative_slope) * ( - mask * x + (1 - mask) * (x - (dotprod / (d_norm_sq + EPS)) * d)) - return x_out - - -class VNReLU(VNLeakyReLU): - - def __init__(self, in_channels, dim, share_nonlinearity=False): - super().__init__( - in_channels, - dim=dim, - share_nonlinearity=share_nonlinearity, - negative_slope=0., - ) - - -class VNLinearBNLeakyReLU(nn.Module): - - def __init__( - self, - in_channels, - out_channels, - dim=5, - share_nonlinearity=False, - negative_slope=0.2, - ): - super().__init__() - - self.linear = VNLinear(in_channels, out_channels, dim=dim) - self.batchnorm = VNBatchNorm(out_channels, dim=dim) - self.leaky_relu = VNLeakyReLU( - out_channels, - dim=dim, - share_nonlinearity=share_nonlinearity, - negative_slope=negative_slope, - ) - - def forward(self, x): - """ - Args: - x: point features of shape [B, C_in, 3, N, ...] - - Returns: - [B, C_out, 3, N, ...] - """ - # Linear - p = self.linear(x) - # BatchNorm - p = self.batchnorm(p) - # LeakyReLU - p = self.leaky_relu(p) - return p - - -class VNMaxPool(nn.Module): - - def __init__(self, in_channels): - super().__init__() - - self.map_to_dir = conv1x1(in_channels, in_channels, dim=4) - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N] - - Returns: - [B, C, 3], features after max-pooling - """ - d = self.map_to_dir(x) - dotprod = (x * d).sum(2, keepdims=True) - idx = dotprod.max(dim=-1, keepdim=False)[1] - index_tuple = torch.meshgrid([torch.arange(j) - for j in x.size()[:-1]]) + (idx, ) - x_max = x[index_tuple] - return x_max - - -class VNLayerNorm(nn.Module): - - def __init__(self, num_features): - super().__init__() - - self.ln = nn.LayerNorm(num_features) - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N, ...] - - Returns: - features of the same shape after LN in each instance - """ - norm = torch.norm(x, dim=2) + EPS # [B, C, N, ...] 
- norm_ln = self.ln(norm.transpose(1, -1)).transpose(1, -1) - norm = norm.unsqueeze(2) - norm_ln = norm_ln.unsqueeze(2) - x = x / norm * norm_ln - return x - - -class VNInFeature(nn.Module): - """VN-Invariant layer.""" - - def __init__( - self, - in_channels, - dim=4, - share_nonlinearity=False, - negative_slope=0.2, - use_rmat=False, - ): - super().__init__() - - self.dim = dim - self.use_rmat = use_rmat - self.vn1 = VNLinearBNLeakyReLU( - in_channels, - in_channels // 2, - dim=dim, - share_nonlinearity=share_nonlinearity, - negative_slope=negative_slope, - ) - self.vn2 = VNLinearBNLeakyReLU( - in_channels // 2, - in_channels // 4, - dim=dim, - share_nonlinearity=share_nonlinearity, - negative_slope=negative_slope, - ) - self.vn_lin = conv1x1( - in_channels // 4, 2 if self.use_rmat else 3, dim=dim) - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N, ...] - - Returns: - rotation invariant features of the same shape - """ - z = self.vn1(x) - z = self.vn2(z) - z = self.vn_lin(z) # [B, 3, 3, N] or [B, 2, 3, N] - if self.use_rmat: - z = z.flatten(1, 2).transpose(1, 2).contiguous() # [B, N, 6] - z = rot6d_to_matrix(z) # [B, N, 3, 3] - z = z.permute(0, 2, 3, 1) # [B, 3, 3, N] - z = z.transpose(1, 2).contiguous() - - if self.dim == 4: - x_in = torch.einsum('bijm,bjkm->bikm', x, z) - elif self.dim == 3: - x_in = torch.einsum('bij,bjk->bik', x, z) - elif self.dim == 5: - x_in = torch.einsum('bijmn,bjkmn->bikmn', x, z) - else: - raise NotImplementedError(f'dim={self.dim} is not supported') - - return x_in - - -class VNEqFeature(VNInFeature): - """Map VN-IN features back to their original rotation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.z = None - - def forward(self, x): - """ - Args: - x: point features of shape [B, C, 3, N, ...] 
- - Returns: - rotation invariant features of the same shape - """ - # map to invariant - if self.z is None: - z = self.vn1(x) - z = self.vn2(z) - z = self.vn_lin(z) - if self.use_rmat: - z = z.flatten(1, 2).transpose(1, 2).contiguous() # [B, N, 6] - z = rot6d_to_matrix(z) # [B, N, 3, 3] - z = z.permute(0, 2, 3, 1) # [B, 3, 3, N] - self.z = z - z = z.transpose(1, 2).contiguous() - # map to equivariant - else: - z = self.z.contiguous() - self.z = None - - if self.dim == 4: - x = torch.einsum('bijm,bjkm->bikm', x, z) - elif self.dim == 3: - x = torch.einsum('bij,bjk->bik', x, z) - elif self.dim == 5: - x = torch.einsum('bijmn,bjkmn->bikmn', x, z) - else: - raise NotImplementedError(f'dim={self.dim} is not supported') - - return x - - -""" test code -import torch -from multi_part_assembly.models import VNLayerNorm -from multi_part_assembly.utils import random_rotation_matrixs -vn_ln = VNLayerNorm(16) -pc = torch.rand(2, 16, 3, 100) -rmat = random_rotation_matrixs(2) -rot_pc = rmat[:, None] @ pc -ln_pc = vn_ln(pc) -rot_ln_pc = rmat[:, None] @ ln_pc -ln_rot_pc = vn_ln(rot_pc) -(rot_ln_pc - ln_rot_pc).abs().max() -""" diff --git a/multi_part_assembly/models/modules/vnn/transformer.py b/multi_part_assembly/models/modules/vnn/transformer.py deleted file mode 100644 index e2bbcab..0000000 --- a/multi_part_assembly/models/modules/vnn/transformer.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Code borrowed from https://github.com/karpathy/minGPT""" - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops.layers.torch import Rearrange - -from .modules import VNLinear, VNLayerNorm, VNReLU, VNLeakyReLU - - -class VNSelfAttention(nn.Module): - """Inspired by VNT-Net: https://arxiv.org/pdf/2205.09690.pdf. - - Note that, we cannot use dropout in VN networks. - """ - - def __init__(self, d_model, n_head, dropout=0.): - super().__init__() - - assert d_model % n_head == 0 - assert dropout == 0. - self.n_head = n_head - - # key, query, value projections for all heads - self.key = VNLinear(d_model, d_model, dim=4) - self.query = VNLinear(d_model, d_model, dim=4) - self.value = VNLinear(d_model, d_model, dim=4) - self.in_rearrange = Rearrange( - 'B (nh hs) D N -> B nh N (hs D)', nh=n_head, D=3) - - # output projection - self.out_rearrange = Rearrange( - 'B nh N (hs D) -> B (nh hs) D N', nh=n_head, D=3) - self.proj = VNLinear(d_model, d_model, dim=4) - - def forward(self, x, src_key_padding_mask=None): - """Forward pass. - - Args: - x: [B, C, 3, N] - src_key_padding_mask: None or [B, N], True means padded tokens - - Returns: - [B, C, 3, N] - """ - # [B, nh, N, hs*3] - k = self.in_rearrange(self.key(x)) - q = self.in_rearrange(self.query(x)) - v = self.in_rearrange(self.value(x)) - - # [B, nh, N, hs*3] x [B, nh, N, hs*3] --> [B, nh, N, N] - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - if src_key_padding_mask is not None: - assert src_key_padding_mask.dtype == torch.bool - mask = src_key_padding_mask[:, None, None, :] # [B, 1, 1, N] - att = att.masked_fill(mask, float('-inf')) - att = F.softmax(att, dim=-1) - y = att @ v # [B, nh, N, N] x [B, nh, N, hs*3] --> [B, nh, N, hs*3] - # back to [B, C, 3, N] - y = self.out_rearrange(y) - - # output projection - y = self.proj(y) - return y - - -class VNTransformerEncoderLayer(nn.Module): - """VN Transformer block.""" - - def __init__(self, d_model, n_head, relu=True, dropout=0.): - super().__init__() - - assert dropout == 0. 
- self.ln1 = VNLayerNorm(d_model) - self.ln2 = VNLayerNorm(d_model) - self.attn = VNSelfAttention( - d_model=d_model, - n_head=n_head, - ) - self.mlp = nn.Sequential( - VNLinear(d_model, 4 * d_model, dim=4), - VNReLU(4 * d_model, 4) if relu else VNLeakyReLU(4 * d_model, 4), - VNLinear(4 * d_model, d_model, dim=4), - ) - - def forward(self, x, src_key_padding_mask=None, src_mask=None): - """Forward pass. - - Args: - x: [B, C, 3, N] - src_key_padding_mask: None or [B, N], True means padded tokens - src_mask: useless, to be compatible with nn.TransformerEncoderLayer - - Returns: - [B, C, 3, N] - """ - x = x + self.attn(self.ln1(x), src_key_padding_mask) - x = x + self.mlp(self.ln2(x)) - return x - - -""" test code -import torch -from multi_part_assembly.models import VNTransformerEncoderLayer -from multi_part_assembly.utils import random_rotation_matrixs -vn_attn = VNTransformerEncoderLayer(16, 4, True, 0) -pc = torch.rand(2, 16, 3, 100) -rmat = random_rotation_matrixs(2) -rot_pc = rmat[:, None] @ pc -attn_pc = vn_attn(pc) -rot_attn_pc = rmat[:, None] @ attn_pc -attn_rot_pc = vn_attn(rot_pc) -(rot_attn_pc - attn_rot_pc).abs().max() -""" diff --git a/multi_part_assembly/models/pn_transformer/__init__.py b/multi_part_assembly/models/pn_transformer/__init__.py index 91ed2a1..72db813 100644 --- a/multi_part_assembly/models/pn_transformer/__init__.py +++ b/multi_part_assembly/models/pn_transformer/__init__.py @@ -1,5 +1,3 @@ -from .transformer import TransformerEncoder, VNTransformerEncoder +from .transformer import TransformerEncoder from .network import PNTransformer -from .network_gan import PNTransformerGAN from .network_refine import PNTransformerRefine -from .vn_network import VNPNTransformer, VNPNTransformerV2 diff --git a/multi_part_assembly/models/pn_transformer/network_gan.py b/multi_part_assembly/models/pn_transformer/network_gan.py deleted file mode 100644 index c3cf6cb..0000000 --- a/multi_part_assembly/models/pn_transformer/network_gan.py +++ /dev/null @@ -1,221 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim - -from multi_part_assembly.models import build_encoder -from multi_part_assembly.utils import transform_pc -from multi_part_assembly.utils import CosineAnnealingWarmupRestarts - -from .network import PNTransformer - - -class ShapeDiscriminator(nn.Module): - - def __init__(self, encoder_arch, feat_dim): - super().__init__() - - self.encoder = build_encoder( - encoder_arch, feat_dim=feat_dim, global_feat=True) - self.classifier = nn.Linear(feat_dim, 1) - - def forward(self, x): - feats = self.encoder(x) # [B, C] - pred = self.classifier(feats) # [B, 1] - return pred.squeeze(-1) # [B] - - -class PNTransformerGAN(PNTransformer): - """PNTransformer with discriminator. 
-
-    Encoder: PointNet extracting per-part global point cloud features
-    Correlator: TransformerEncoder performs part interactions
-    Predictor: MLP-based pose predictor
-    Discriminator: PointNet-based classifier
-    """
-
-    def __init__(self, cfg):
-        super().__init__(cfg)
-
-        self.discriminator = self._init_discriminator()
-        self.d_npoint = self.cfg.model.discriminator_num_points
-
-        # loss configs
-        adv_loss = self.cfg.model.discriminator_loss
-        assert adv_loss in ['mse', 'ce']
-        self.adv_loss_fn = nn.MSELoss() if \
-            adv_loss == 'mse' else nn.BCEWithLogitsLoss()
-
-    def _init_discriminator(self):
-        discriminator = ShapeDiscriminator(self.cfg.model.discriminator,
-                                           self.cfg.model.pc_feat_dim)
-        return discriminator
-
-    @staticmethod
-    def _sample_points(part_pcs, valids, sample_num):
-        """Sample points from valid parts to produce a shape point cloud.
-
-        Args:
-            part_pcs: [B, P, N, 3]
-            valids: [B, P], 1 is valid, 0 is padded
-            sample_num: int, number of points to sample
-        """
-        B, P, N, _ = part_pcs.shape
-        part_pcs = part_pcs.flatten(1, 2)  # [B, P*N, 3]
-        # in case `valids` == [1., 1., ..., 1.] (all_ones)
-        valids = torch.cat([valids, torch.zeros(B, 1).type_as(valids)], dim=1)
-        num_valid_parts = valids.argmin(1)  # find the first `0` in `valids`
-        all_idx = torch.stack([
-            torch.randperm(num_valid_parts[i] * N)[:sample_num]
-            for i in range(B)
-        ]).type_as(num_valid_parts)  # [B, num_samples]
-        batch_idx = torch.arange(B)[:, None].type_as(all_idx)
-        pcs = part_pcs[batch_idx, all_idx]  # [B, num_samples, 3]
-        return pcs
-
-    def _loss_function(self, data_dict, out_dict={}, optimizer_idx=-1):
-        """Inner loop for sampling loss computation.
-
-        Besides the translation and rotation loss, also compute the GAN loss.
-        """
-        if optimizer_idx == -1:  # in eval mode
-            assert not self.training
-            return super()._loss_function(data_dict, out_dict)
-
-        batch_size = data_dict['part_pcs'].shape[0]
-        if optimizer_idx == 0:  # g step
-            loss_dict, out_dict = super()._loss_function(data_dict, out_dict)
-            real_pts = out_dict['pred_trans_pts']  # [B, P, N, 3]
-            real_pts = self._sample_points(real_pts, data_dict['part_valids'],
-                                           self.d_npoint)  # [B, n, 3]
-            real_logits = self.discriminator(real_pts)  # [B]
-            # label predicted assemblies as real to fool the discriminator
-            real = torch.ones(batch_size).type_as(real_pts).detach()
-            g_loss = self.adv_loss_fn(real_logits, real)
-            loss_dict.update({'g_loss': g_loss})
-            return loss_dict, out_dict
-
-        assert optimizer_idx == 1  # d step
-        part_pcs, valids = data_dict['part_pcs'], data_dict['part_valids']
-
-        # generate
-        forward_dict = {
-            'part_pcs': part_pcs,
-            'part_valids': valids,
-            'part_label': data_dict['part_label'],
-            'instance_label': data_dict['instance_label'],
-            'pre_pose_feats': out_dict.get('pre_pose_feats', None),
-        }
-        with torch.no_grad():
-            out_dict = self.forward(forward_dict)
-
-        pred_trans, pred_rot = out_dict['trans'], out_dict['rot']
-        gt_trans, gt_rot = data_dict['part_trans'], data_dict['part_rot']
-        pred_pts = transform_pc(pred_trans, pred_rot, part_pcs).detach()
-        pred_pts = self._sample_points(pred_pts, valids, self.d_npoint)
-        gt_pts = transform_pc(gt_trans, gt_rot, part_pcs).detach()
-        gt_pts = self._sample_points(gt_pts, valids, self.d_npoint)
-        real = torch.ones(batch_size).type_as(part_pcs).detach()
-        fake = torch.zeros(batch_size).type_as(part_pcs).detach()
-
-        real_loss = self.adv_loss_fn(self.discriminator(gt_pts), real)
-        fake_loss = self.adv_loss_fn(self.discriminator(pred_pts), fake)
-        d_loss = 0.5 * (real_loss + fake_loss)
-        return {'d_loss': d_loss}, out_dict
-
-    def loss_function(self, data_dict, optimizer_idx):
-        """Wrapper for computing MoN loss.
-
-        We sample predictions multiple times and return the minimum loss.
-
-        Args:
-            data_dict: from dataloader
-            optimizer_idx: 0 --> Generator step, 1 --> Discriminator step;
-                -1 --> eval mode
-        """
-        if optimizer_idx == -1:  # in eval mode
-            return super().loss_function(data_dict, optimizer_idx)
-
-        loss_dict = None
-        out_dict = {}
-        for _ in range(self.sample_iter):
-            sample_loss, out_dict = self._loss_function(
-                data_dict, out_dict, optimizer_idx=optimizer_idx)
-
-            if loss_dict is None:
-                loss_dict = {k: [] for k in sample_loss.keys()}
-            for k, v in sample_loss.items():
-                loss_dict[k].append(v)
-        loss_dict = {k: torch.stack(v, dim=0) for k, v in loss_dict.items()}
-
-        if optimizer_idx == 1:  # d step, only `d_loss`
-            d_loss = loss_dict['d_loss'].mean()
-            return {
-                'd_loss': d_loss,
-                'loss': d_loss * self.cfg.loss.d_loss_w,
-            }
-
-        assert optimizer_idx == 0
-
-        # `g_loss` is not involved in the MoN loss computation
-        g_loss = loss_dict.pop('g_loss').mean()
-
-        # take the min for each data in the batch
-        total_loss = 0.
-        for k, v in loss_dict.items():
-            if 'loss' in k:  # we may log some other metrics in eval, e.g. acc
-                total_loss += v * eval(f'self.cfg.loss.{k}_w')  # weighting
-        loss_dict['loss'] = total_loss
-
-        # `total_loss` is of shape [sample_iter, B]
-        min_idx = total_loss.argmin(0)  # [B]
-        B = min_idx.shape[0]
-        batch_idx = torch.arange(B).type_as(min_idx)
-        loss_dict = {
-            k: v[min_idx, batch_idx].mean()
-            for k, v in loss_dict.items()
-        }
-
-        # add `g_loss`
-        loss_dict['g_loss'] = g_loss
-        loss_dict['loss'] = loss_dict['loss'] + g_loss * self.cfg.loss.g_loss_w
-
-        return loss_dict
-
-    def configure_optimizers(self):
-        """Build optimizer and lr scheduler."""
-        g_lr = self.cfg.optimizer.g_lr
-        d_lr = self.cfg.optimizer.d_lr
-        g_opt = optim.Adam(
-            list(self.encoder.parameters()) +
-            list(self.corr_module.parameters()) +
-            list(self.pose_predictor.parameters()),
-            lr=g_lr)
-        d_opt = optim.Adam(self.discriminator.parameters(), lr=d_lr)
-
-        if self.cfg.optimizer.lr_scheduler:
-            assert self.cfg.optimizer.lr_scheduler in ['cosine']
-            clip_lr = min(g_lr, d_lr) / self.cfg.optimizer.lr_decay_factor
-            total_epochs = self.cfg.exp.num_epochs
-            warmup_epochs = int(total_epochs * self.cfg.optimizer.warmup_ratio)
-            g_scheduler = CosineAnnealingWarmupRestarts(
-                g_opt,
-                total_epochs,
-                max_lr=g_lr,
-                min_lr=clip_lr,
-                warmup_steps=warmup_epochs)
-            d_scheduler = CosineAnnealingWarmupRestarts(
-                d_opt,
-                total_epochs,
-                max_lr=d_lr,
-                min_lr=clip_lr,
-                warmup_steps=warmup_epochs)
-            return (
-                [g_opt, d_opt],
-                [{
-                    'scheduler': g_scheduler,
-                    'interval': 'epoch',
-                }, {
-                    'scheduler': d_scheduler,
-                    'interval': 'epoch',
-                }],
-            )
-        return [g_opt, d_opt]
diff --git a/multi_part_assembly/models/pn_transformer/transformer.py b/multi_part_assembly/models/pn_transformer/transformer.py
index 1a7aeab..7caa935 100644
--- a/multi_part_assembly/models/pn_transformer/transformer.py
+++ b/multi_part_assembly/models/pn_transformer/transformer.py
@@ -1,7 +1,5 @@
 import torch.nn as nn
 
-from multi_part_assembly.models import VNEqFeature
-
 
 def build_transformer_encoder(
     d_model,
@@ -79,67 +77,3 @@ def forward(self, tokens, valid_masks):
             pad_masks = None
         out = self.transformer_encoder(tokens, src_key_padding_mask=pad_masks)
         return self.out_fc(out)
-
-
-class VNTransformerEncoder(TransformerEncoder):
-    """VNTransformer encoder with padding_mask.
-
-    It first maps tokens to invariant features.
-    Then, it applies the normal TransformerEncoder to perform interactions.
-    Finally, it maps the invariant features back to the equivariant space,
-    recovering the rotation of the tokens.
-    """
-
-    def __init__(
-        self,
-        d_model,
-        num_heads,
-        num_layers,
-        dropout=0.,
-        out_dim=None,
-    ):
-        super().__init__(
-            d_model=d_model * 3,
-            num_heads=num_heads,
-            ffn_dim=d_model * 3 * 4,
-            num_layers=num_layers,
-            norm_first=True,
-            dropout=dropout,
-            out_dim=out_dim,
-        )
-
-        # canonicalizer, map to invariant space then back to equivariant space
-        self.feats_can = VNEqFeature(d_model, dim=4)
-
-    def forward(self, tokens, valid_masks):
-        """Forward pass.
-
-        Args:
-            tokens: [B, C, 3, N]
-            valid_masks: [B, N], True for valid, False for padded
-
-        Returns:
-            torch.Tensor: [B, C, 3, N]
-        """
-        # map tokens to invariant features
-        tokens_in = self.feats_can(tokens).flatten(1, 2)  # [B, C*3, N]
-        tokens_in = tokens_in.transpose(1, 2).contiguous()  # [B, N, C*3]
-        out_in = super().forward(tokens_in, valid_masks)  # [B, N, C*3]
-        # back to [B, C, 3, N]
-        out_in = out_in.transpose(1, 2).unflatten(1, (-1, 3)).contiguous()
-        out_eq = self.feats_can(out_in)
-        return out_eq
-
-
-""" test code
-import torch
-from multi_part_assembly.models.pn_transformer import VNTransformerEncoder
-from multi_part_assembly.utils import random_rotation_matrixs
-vn_trans = VNTransformerEncoder(16, 4, 2, 0.).eval()
-pc = torch.rand(2, 16, 3, 10)  # 10 parts
-rmat = random_rotation_matrixs((2, 10))  # [2, 10, 3, 3]
-rot_pc = (rmat @ pc.transpose(1, -1)).transpose(1, -1)
-trans_pc = vn_trans(pc, None)
-rot_trans_pc = (rmat @ trans_pc.transpose(1, -1)).transpose(1, -1)
-trans_rot_pc = vn_trans(rot_pc, None)
-# should be ~0: rotating the input equals rotating the output (equivariance)
-(rot_trans_pc - trans_rot_pc).abs().max()
-"""
diff --git a/multi_part_assembly/models/pn_transformer/vn_network.py b/multi_part_assembly/models/pn_transformer/vn_network.py
deleted file mode 100644
index d27e9a4..0000000
--- a/multi_part_assembly/models/pn_transformer/vn_network.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import torch
-
-from multi_part_assembly.models import build_encoder, VNEqFeature, VNPoseRegressor, PoseRegressor
-
-from .network import PNTransformer
-from .transformer import VNTransformerEncoder, TransformerEncoder
-
-
-class VNPNTransformer(PNTransformer):
-    """SO(3) equivariant PointNet-Transformer based multi-part assembly model.
-
-    This model should only be used in geometric assembly, and uses the
-    rotation matrix as the rotation representation.
-    This is because 1) the 6d rotation representation can be parametrized as
-    a (2, 3) matrix, which is compatible with the (C, 3) shape in VN models,
-    and 2) in semantic assembly we also input instance labels and random
-    noise, which cannot preserve rotation equivariance.
-
-    Encoder: VNPointNet extracting per-part global point cloud features
-    Correlator: VNTransformerEncoder performs part interactions
-    Predictor: VN MLP for rotation and VN-In MLP for translation
-    """
-
-    def __init__(self, cfg):
-        super().__init__(cfg)
-
-        # see the above class docstring
-        assert self.rot_type == 'rmat', 'VNPNTransformer should predict rmat'
-        assert not self.semantic, 'VNPNTransformer is for geometric assembly'
-
-    def _init_encoder(self):
-        """Part point cloud encoder."""
-        encoder = build_encoder(
-            self.cfg.model.encoder,
-            feat_dim=self.pc_feat_dim,
-            global_feat=True,
-            pool1='mean',
-            pool2='max',
-        )
-        return encoder
-
-    def _init_corr_module(self):
-        """Part feature interaction module."""
-        corr_module = VNTransformerEncoder(
-            d_model=self.pc_feat_dim,
-            num_heads=self.cfg.model.transformer_heads,
-            num_layers=self.cfg.model.transformer_layers,
-            dropout=0.,
-        )
-        return corr_module
-
-    def _init_pose_predictor(self):
-        """Final pose estimator."""
-        # only use features as input in VN models
-        assert self.cfg.loss.noise_dim == 0
-        pose_predictor = VNPoseRegressor(
-            feat_dim=self.pc_feat_dim,
-            rot_type=self.rot_type,
-        )
-        return pose_predictor
-
-    def _extract_part_feats(self, part_pcs, part_valids):
-        """Extract per-part point cloud features."""
-        B, P, N, _ = part_pcs.shape  # [B, P, N, 3]
-        valid_mask = (part_valids == 1)
-        # shared-weight encoder
-        valid_pcs = part_pcs[valid_mask]  # [n, N, 3]
-        valid_feats = self.encoder(valid_pcs)  # [n, C, 3]
-        pc_feats = torch.zeros(B, P, self.pc_feat_dim, 3).type_as(valid_feats)
-        pc_feats[valid_mask] = valid_feats
-        return pc_feats
-
-    def forward(self, data_dict):
-        """Forward pass to predict poses for each part.
-
-        Args:
-            data_dict should contain:
-                - part_pcs: [B, P, N, 3]
-                - part_valids: [B, P], 1 are valid parts, 0 are padded parts
-            may contain:
-                - pre_pose_feats: [B, P, C', 3] (reused) or None
-        """
-        feats = data_dict.get('pre_pose_feats', None)
-
-        if feats is None:
-            part_pcs = data_dict['part_pcs']
-            part_valids = data_dict['part_valids']
-            pc_feats = self._extract_part_feats(part_pcs, part_valids)
-            # transformer feature fusion
-            # [B, P, C, 3] --> [B, C, 3, P]
-            pc_feats = pc_feats.permute(0, 2, 3, 1).contiguous()
-            valid_mask = (part_valids == 1)  # [B, P]
-            corr_feats = self.corr_module(pc_feats, valid_mask)  # [B, C, 3, P]
-            # MLP predict poses
-            # [B, C, 3, P] --> [B, P, C, 3]
-            feats = corr_feats.permute(0, 3, 1, 2).contiguous()
-        rot, trans = self.pose_predictor(feats)
-        rot = self._wrap_rotation(rot)
-
-        pred_dict = {
-            'rot': rot,  # [B, P, 4/(3, 3)], Rotation3D
-            'trans': trans,  # [B, P, 3]
-            'pre_pose_feats': feats,  # [B, P, C', 3]
-        }
-        return pred_dict
-
-
-class VNPNTransformerV2(VNPNTransformer):
-    """Variant that runs a standard TransformerEncoder on canonicalized
-    (invariant) features, then maps the predicted rotation back to the
-    equivariant space.
-    """
-
-    def __init__(self, cfg):
-        super().__init__(cfg)
-
-        self.feats_can = VNEqFeature(
-            self.pc_feat_dim, dim=4, use_rmat=self.cfg.model.rmat_can)
-
-    def _init_corr_module(self):
-        """Part feature interaction module."""
-        corr_module = TransformerEncoder(
-            d_model=self.pc_feat_dim * 3,
-            num_heads=self.cfg.model.transformer_heads,
-            ffn_dim=self.pc_feat_dim * 3 * 4,
-            num_layers=self.cfg.model.transformer_layers,
-            norm_first=True,
-            dropout=0.,
-        )
-        return corr_module
-
-    def _init_pose_predictor(self):
-        """Final pose estimator."""
-        # only use features as input in VN models
-        assert self.cfg.loss.noise_dim == 0
-        pose_predictor = PoseRegressor(
-            feat_dim=self.pc_feat_dim * 3,
-            rot_type=self.rot_type,
-            norm_rot=self.cfg.model.rmat_can,  # use rmat in canonicalization
-        )
-        return pose_predictor
-
-    def forward(self, data_dict):
-        """Forward pass to predict poses for each part.
-
-        Args:
-            data_dict should contain:
-                - part_pcs: [B, P, N, 3]
-                - part_valids: [B, P], 1 are valid parts, 0 are padded parts
-            may contain:
-                - pre_pose_feats: [B, P, C'*3] (reused) or None
-        """
-        feats = data_dict.get('pre_pose_feats', None)
-        assert feats is None
-
-        part_pcs = data_dict['part_pcs']
-        part_valids = data_dict['part_valids']
-        pc_feats = self._extract_part_feats(part_pcs, part_valids)
-        # [B, P, C, 3] --> [B, C, 3, P]
-        pc_feats = pc_feats.permute(0, 2, 3, 1).contiguous()
-        pc_feats = self.feats_can(pc_feats)  # [B, C, 3, P], invariant
-        # to [B, P, C*3]
-        pc_feats = pc_feats.flatten(1, 2).transpose(1, 2).contiguous()
-        # transformer feature fusion
-        valid_mask = (part_valids == 1)  # [B, P]
-        feats = self.corr_module(pc_feats, valid_mask)  # [B, P, C*3]
-        rot, trans = self.pose_predictor(feats)  # [B, P, 6], [B, P, 3]
-        # translation prediction is invariant, which is what we want
-        # we need to make rotation prediction equivariant
-        # to [B, 2, 3, P]
-        rot = rot.transpose(1, 2).unflatten(1, (2, 3)).contiguous()
-        rot = self.feats_can(rot)  # [B, 2, 3, P], equivariant
-        rot = rot.permute(0, 3, 1, 2).contiguous()  # [B, P, 2, 3]
-        rot = self._wrap_rotation(rot)
-
-        pred_dict = {
-            'rot': rot,  # [B, P, 4/(3, 3)], Rotation3D
-            'trans': trans,  # [B, P, 3]
-            'pre_pose_feats': feats,  # [B, P, C'*3]
-        }
-        return pred_dict
diff --git a/multi_part_assembly/utils/__init__.py b/multi_part_assembly/utils/__init__.py
index c9ad745..ef10c6f 100644
--- a/multi_part_assembly/utils/__init__.py
+++ b/multi_part_assembly/utils/__init__.py
@@ -7,6 +7,6 @@ from .utils import colorize_part_pc, filter_wd_parameters, _get_clones, \
     pickle_load, pickle_dump, save_pc
 from .eval_utils import trans_metrics, rot_metrics, calc_part_acc, \
-    calc_connectivity_acc, rot_geodesic_dist, relative_pose_metrics
+    calc_connectivity_acc
 from .lr import CosineAnnealingWarmupRestarts, LinearAnnealingWarmup
 from .config_utils import merge_cfg
diff --git a/multi_part_assembly/utils/eval_utils.py b/multi_part_assembly/utils/eval_utils.py
index ea5ca0e..8a8d976 100644
--- a/multi_part_assembly/utils/eval_utils.py
+++ b/multi_part_assembly/utils/eval_utils.py
@@ -197,109 +197,3 @@ def rot_metrics(rot1, rot2, valids, metric):
     metric_per_data = diff.abs().mean(dim=-1)
     metric_per_data = _valid_mean(metric_per_data, valids)
     return metric_per_data
-
-
-@torch.no_grad()
-def rot_geodesic_dist(rot1, rot2, valids):
-    """Evaluation metric for rotation using geodesic distance.
-
-    According to https://www.cs.cmu.edu/~cga/dynopt/readings/Rmetric.pdf
-    Section 4, Euler angle MSE is not a good metric.
-    So we adopt the `Geodesic on the Unit Sphere` metric introduced in the
-    paper Section 3.6. The authors prove that this equals
-    `2·arccos(|q1·q2|)` (see eq.34 of the paper).
-
-    Args:
-        rot1: [B, P, 4/(3, 3)], Rotation3D, quat or rmat
-        rot2: [B, P, 4/(3, 3)], Rotation3D, quat or rmat
-        valids: [B, P], 1 for input parts, 0 for padded parts
-
-    Returns:
-        [B], metric per data in the batch
-    """
-    quat1 = rot1.to_quat()  # [B, P, 4]
-    quat2 = rot2.to_quat()
-    metric_per_data = 2. * torch.acos((quat1 * quat2).sum(dim=-1).abs())
-    metric_per_data = _valid_mean(metric_per_data, valids)
-    return metric_per_data * 180. / math.pi  # to degree
-
-
-@torch.no_grad()
-def relative_pose_metrics(trans1, trans2, rot1, rot2, valids):
-    """Relative pose error for geometric assembly.
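-
-    For a canonical part i and another part j, the relative pose is
-    rel_R = R_i^T @ R_j and rel_T = R_i^T @ (T_j - T_i).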
- - Since it's hard to define canonical pose for each shape (e.g. symmetry), - we take each part as the canonical pose, calculate the relative pose - errors from other shapes to it, and take the min of them. - - Args: - trans1: [B, P, 3] - trans2: [B, P, 3] - rot1: [B, P, 4/(3, 3)], Rotation3D, quat or rmat - rot2: [B, P, 4/(3, 3)], Rotation3D, quat or rmat - valids: [B, P], 1 for input parts, 0 for padded parts - - Returns: - [B], [B], translation/rotation error per data in the batch - """ - B, P = valids.shape - valids = valids.float() - rmat1 = rot1.to_rmat() # [B, P, 3, 3] - rmat2 = rot2.to_rmat() - - def _get_relative_pose(R, T): - """Get relative pose from canonical pose. - - Args: - R: [B, P, 3, 3] - T: [B, P, 3] - """ - R1 = R.unsqueeze(2) # [B, P, 1, 3, 3] - R2 = R.unsqueeze(1) # [B, 1, P, 3, 3] - T1 = T.unsqueeze(2) # [B, P, 1, 3] - T2 = T.unsqueeze(1) # [B, 1, P, 3] - rel_R = R1.transpose(-1, -2) @ R2 # [B, P, P, 3, 3] - rel_T = (R1.transpose(-1, -2) @ ( - (T2 - T1)[..., None]))[..., 0] # [B, P, P, 3] - # [B, i, j, ...] is when i is canonical, relative pose from j to i - return rel_R, rel_T - - rel_R1, rel_T1 = _get_relative_pose(rmat1, trans1) - rel_R2, rel_T2 = _get_relative_pose(rmat2, trans2) - # tile valid_labels for each canonical pose - rel_valids = valids.unsqueeze(1).repeat(1, P, 1) # [B, P, P] - # take all the elements except the diagonal - mask = torch.eye(P).bool().unsqueeze(0).to(valids.device) # [1, P, P] - mask = (~mask).repeat(B, 1, 1) # [B, P, P] - rel_R1 = rel_R1[mask].unflatten(0, (B * P, P - 1)) - rel_R2 = rel_R2[mask].unflatten(0, (B * P, P - 1)) - rel_T1 = rel_T1[mask].unflatten(0, (B * P, P - 1)) - rel_T2 = rel_T2[mask].unflatten(0, (B * P, P - 1)) - rel_valids = rel_valids[mask].unflatten(0, (B * P, P - 1)) - rel_R1 = Rotation3D(rel_R1, rot_type='rmat').convert('quat') - rel_R2 = Rotation3D(rel_R2, rot_type='rmat').convert('quat') - - def _min_error(errors, min_idx=None): - """Take the canonical pose with the min error.""" - # errors: [B*P], should mask out invalid parts - errors = errors.reshape(B, P) - if min_idx is None: - shift_errors = errors + 1e9 * (1. 
- valids) - min_idx = shift_errors.argmin(dim=1, keepdim=True) # [B, 1] - min_errors = torch.gather(errors, dim=1, index=min_idx) - return min_errors, min_idx - - metric_dict = {} - idx = None - for metric in ['mse', 'rmse', 'mae']: - # we use relative translation MSE to select the canonical pose - trans_errors = trans_metrics(rel_T1, rel_T2, rel_valids, metric=metric) - metric_dict[f'rel_trans_{metric}'], idx = _min_error( - trans_errors, min_idx=idx) - rot_errors = rot_metrics(rel_R1, rel_R2, rel_valids, metric=metric) - metric_dict[f'rel_rot_{metric}'], idx = _min_error( - rot_errors, min_idx=idx) - metric_dict['rel_geo_rot'], _ = _min_error( - rot_geodesic_dist(rel_R1, rel_R2, rel_valids), min_idx=idx) - - return metric_dict diff --git a/multi_part_assembly/version.py b/multi_part_assembly/version.py index b794fd4..7fd229a 100644 --- a/multi_part_assembly/version.py +++ b/multi_part_assembly/version.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = '0.2.0' diff --git a/scripts/collect_test.py b/scripts/collect_test.py index 2c15972..0b86899 100644 --- a/scripts/collect_test.py +++ b/scripts/collect_test.py @@ -39,14 +39,14 @@ def test(cfg): strategy='dp' if len(all_gpus) > 1 else None, ) + # TODO: modify this to fit in the metrics you want to report all_metrics = { 'rot_rmse': 1., 'rot_mae': 1., - 'geo_rot': 1., - 'trans_rmse': 100., - 'trans_mae': 100., - 'transform_pt_cd_loss': 1000., - 'part_acc': 100., + 'trans_rmse': 100., # presented as \times 1e-2 in the table + 'trans_mae': 100., # presented as \times 1e-2 in the table + 'transform_pt_cd_loss': 1000., # presented as \times 1e-3 in the table + 'part_acc': 100., # presented in % in the table } # performance on all categories @@ -71,9 +71,14 @@ def test(cfg): all_results[metric] = np.mean(all_results[metric]).round(1) print(f'{metric}: {all_results[metric]}') # format for latex table + print('\n##############################################') + print('Results averaged over all categories:') result = [str(all_results[metric]) for metric in all_metrics.keys()] print(' & '.join(result)) + if not hasattr(cfg.data, 'all_category'): + return + # iterate over all categories all_category = cfg.data.all_category all_results = { @@ -121,12 +126,23 @@ def test(cfg): } all_results = {k: np.array(v).round(1) for k, v in all_results.items()} # format for latex table + # per-category results + print('\n##############################################') + print('Results per category:') for metric, result in all_results.items(): print(f'{metric}:') result = result.tolist() result.append(np.nanmean(result).round(1)) # per-category mean result = [str(res) for res in result] print(' & '.join(result)) + all_results[metric] = result + # averaged over all categories + print('\n##############################################') + print('Results averaged over all categories:') + all_metric_names = list(all_metrics.keys()) + result = [str(all_results[metric][-1]) for metric in all_metric_names] + print(' & '.join(all_metric_names)) + print(' & '.join(result)) print('Done testing...') diff --git a/scripts/debug.py b/scripts/debug.py deleted file mode 100644 index 5e39f80..0000000 --- a/scripts/debug.py +++ /dev/null @@ -1,23 +0,0 @@ -import os -import sys -import importlib - -from multi_part_assembly.datasets import build_dataloader -from multi_part_assembly.models import build_model - - -def build(cfg_file): - sys.path.append(os.path.dirname(cfg_file)) - cfg = importlib.import_module(os.path.basename(cfg_file)[:-3]) - cfg = cfg.get_cfg_defaults() - - 
cfg.freeze() - print(cfg) - - # Initialize model - model = build_model(cfg) - - # Initialize dataloaders - train_loader, val_loader = build_dataloader(cfg) - - return model, train_loader, val_loader, cfg diff --git a/scripts/dup_run_sbatch.sh b/scripts/dup_run_sbatch.sh index 6630707..8998d4a 100755 --- a/scripts/dup_run_sbatch.sh +++ b/scripts/dup_run_sbatch.sh @@ -5,15 +5,16 @@ ####################################################################### # An example usage: -# GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh \ -# rtx6000 test-sbatch test.py config.py --fp16 --cudnn +# GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh \ +# rtx6000 test-sbatch ./scripts/train.py config.py --fp16 --cudnn ####################################################################### # read args from command line GPUS=${GPUS:-1} -CPUS_PER_TASK=${CPUS_PER_TASK:-8} +CPUS_PER_GPU=${CPUS_PER_GPU:-8} MEM_PER_CPU=${MEM_PER_CPU:-5} QOS=${QOS:-normal} +TIME=${TIME:-0} REPEAT=${REPEAT:-3} PY_ARGS=${@:5} diff --git a/scripts/sbatch_run.sh b/scripts/sbatch_run.sh index d3dc9de..2c8d437 100755 --- a/scripts/sbatch_run.sh +++ b/scripts/sbatch_run.sh @@ -6,15 +6,16 @@ ####################################################################### # An example usage: -# GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal ./scripts/sbatch_run.sh rtx6000 train-sbatch \ +# GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal ./scripts/sbatch_run.sh rtx6000 train-sbatch \ # ./scripts/train.py --cfg_file config.py ####################################################################### # read args from command line GPUS=${GPUS:-1} -CPUS_PER_TASK=${CPUS_PER_TASK:-8} +CPUS_PER_GPU=${CPUS_PER_GPU:-8} MEM_PER_CPU=${MEM_PER_CPU:-5} QOS=${QOS:-normal} +TIME=${TIME:-0} PY_ARGS=${@:4} PARTITION=$1 @@ -25,6 +26,7 @@ SLRM_NAME="${JOB_NAME/\//"_"}" LOG_DIR=checkpoint/$JOB_NAME DATETIME=$(date "+%Y-%m-%d_%H:%M:%S") LOG_FILE=$LOG_DIR/${DATETIME}.log +CPUS_PER_TASK=$((GPUS * CPUS_PER_GPU)) # set up log output folder mkdir -p $LOG_DIR @@ -39,15 +41,13 @@ echo "#!/bin/bash #SBATCH --open-mode=append #SBATCH --partition=$PARTITION # self-explanatory, set to your preference (e.g. 
gpu or cpu on MaRS, p100, t4, or cpu on Vaughan) #SBATCH --cpus-per-task=$CPUS_PER_TASK # self-explanatory, set to your preference -#SBATCH --ntasks=$GPUS -#SBATCH --ntasks-per-node=$GPUS +#SBATCH --ntasks=1 +#SBATCH --ntasks-per-node=1 #SBATCH --mem-per-cpu=${MEM_PER_CPU}G # self-explanatory, set to your preference #SBATCH --gres=gpu:$GPUS # NOTE: you need a GPU for CUDA support; self-explanatory, set to your preference #SBATCH --nodes=1 -#SBATCH --qos=$QOS # for 'high' and 'deadline' QoS, refer to https://support.vectorinstitute.ai/AboutVaughan2 - -# link /checkpoint to current folder -# ln -sfn /checkpoint/\$USER/\$SLURM_JOB_ID $LOG_DIR +#SBATCH --qos=$QOS # self-explanatory, set to your preference +#SBATCH --time=$TIME # running time limit, 0 as unlimited # log some necessary environment params echo \$SLURM_JOB_ID >> $LOG_FILE # log the job id diff --git a/scripts/test.py b/scripts/test.py index 49f4347..5cf51b6 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -31,15 +31,15 @@ def test(cfg): return # if `args.category` is 'all', we also compute per-category results + # TODO: modify this to fit in the metrics you want to report all_category = cfg.data.all_category all_metrics = { 'rot_rmse': 1., 'rot_mae': 1., - 'geo_rot': 1., - 'trans_rmse': 100., - 'trans_mae': 100., - 'transform_pt_cd_loss': 1000., - 'part_acc': 100., + 'trans_rmse': 100., # presented as \times 1e-2 in the table + 'trans_mae': 100., # presented as \times 1e-2 in the table + 'transform_pt_cd_loss': 1000., # presented as \times 1e-3 in the table + 'part_acc': 100., # presented in % in the table } all_results = {metric: [] for metric in all_metrics.keys()} for cat in all_category: diff --git a/scripts/train.py b/scripts/train.py index 0be71b4..16d266c 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -4,6 +4,7 @@ import argparse import importlib +import torch import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor @@ -26,11 +27,14 @@ def main(cfg): ckp_dir = os.path.join(cfg.exp.ckp_dir, cfg_name, 'models') os.makedirs(os.path.dirname(ckp_dir), exist_ok=True) - # on clusters, quota is limited - # soft link temp space for checkpointing + # on clusters, quota under user dir is usually limited + # soft link to save the weights in temp space for checkpointing # TODO: modify this if you are not running on clusters - if SLURM_JOB_ID and os.path.isdir('/checkpoint/'): + CHECKPOINT_DIR = '/checkpoint/' # '' + if SLURM_JOB_ID and CHECKPOINT_DIR and os.path.isdir(CHECKPOINT_DIR): if not os.path.exists(ckp_dir): + # on my cluster, the temp dir is /checkpoint/$USER/$SLURM_JOB_ID + # TODO: modify this if your cluster is different usr = pwd.getpwuid(os.getuid())[0] os.system(r'ln -s /checkpoint/{}/{}/ {}'.format( usr, SLURM_JOB_ID, ckp_dir)) @@ -39,10 +43,10 @@ def main(cfg): # it's not good to hard-code the wandb id # but on preemption clusters, we want the job to resume the same wandb - # process after resuming training + # process after resuming training (i.e. 
drawing the same graph)
     # so we have to keep the same wandb id
     # TODO: modify this if you are not running on preemption clusters
-    preemption = True
+    preemption = True  # set to False if not on a preemption cluster
     if SLURM_JOB_ID and preemption:
         logger_id = logger_name = f'{cfg_name}-{SLURM_JOB_ID}'
     else:
@@ -101,7 +105,15 @@
         print(f'INFO: automatically detect checkpoint {last_ckp}')
         ckp_path = os.path.join(ckp_dir, last_ckp)
     elif cfg.exp.weight_file:
-        ckp_path = cfg.exp.weight_file
+        # check if it has training states, or just a model weight
+        ckp = torch.load(cfg.exp.weight_file, map_location='cpu')
+        # if it has, then it's a checkpoint compatible with pl
+        if 'state_dict' in ckp.keys():
+            ckp_path = cfg.exp.weight_file
+        # if it's just a weight, then manually load it into the model
+        else:
+            ckp_path = None
+            model.load_state_dict(ckp)
     else:
         ckp_path = None
@@ -125,9 +137,11 @@
     cfg = importlib.import_module(os.path.basename(args.cfg_file)[:-3])
     cfg = cfg.get_cfg_defaults()
-    # TODO: modify this line if you can run DDP on the cluster
-    parallel_strategy = 'dp'  # 'ddp'
+    # TODO: modify this if you cannot run DDP training, and want to use DP
+    parallel_strategy = 'ddp'  # 'dp'
     cfg.exp.gpus = args.gpus
+    # manually increase batch_size according to the number of GPUs in DP
+    # not necessary in DDP because it's already the per-GPU batch size
     if len(cfg.exp.gpus) > 1 and parallel_strategy == 'dp':
         cfg.exp.batch_size *= len(cfg.exp.gpus)
         cfg.exp.num_workers *= len(cfg.exp.gpus)
diff --git a/scripts/train_everyday_categories.sh b/scripts/train_everyday_categories.sh
index 772d679..0d5916e 100755
--- a/scripts/train_everyday_categories.sh
+++ b/scripts/train_everyday_categories.sh
@@ -4,7 +4,7 @@
 #######################################################################
 # An example usage:
-# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py
+# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py
 #######################################################################
 
 CMD=$1
diff --git a/scripts/train_one_category.sh b/scripts/train_one_category.sh
index 9c68f26..20c86da 100755
--- a/scripts/train_one_category.sh
+++ b/scripts/train_one_category.sh
@@ -4,7 +4,7 @@
 #######################################################################
 # An example usage:
-# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle
+# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_GPU=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle
 #######################################################################
 
 CMD=$1
diff --git a/scripts/vis.py b/scripts/vis.py
index 97707a0..f068a04 100644
--- a/scripts/vis.py
+++ b/scripts/vis.py
@@ -35,10 +35,7 @@ def visualize(cfg):
             batch = {k: v.float().cuda() for k, v in batch.items()}
             out_dict = model(batch)  # trans/rot: [B, P, 3/4/(3, 3)]
             loss_dict, _ = model.module._calc_loss(out_dict, batch)  # loss is [B]
-            # TODO: the criterion to select examples
-            # loss = -loss_dict['part_acc']
-            # loss = loss_dict['trans_mae'] + loss_dict['rot_mae'] / 1000.
-            # loss = loss_dict['transform_pt_cd_loss']
+            # the criterion to cherry-pick examples
             loss = loss_dict['rot_pt_l2_loss'] + loss_dict['trans_mae']
             # convert all the rotations to quaternion for simplicity
             out_dict = {