Commit

minor fix

Wuziyi616 committed Aug 5, 2022
1 parent 8dc1fd3 commit 00a927d
Showing 5 changed files with 62 additions and 16 deletions.
10 changes: 9 additions & 1 deletion docs/usage.md
@@ -62,12 +62,20 @@ python scripts/test.py --cfg_file $CFG --weight path/to/weight --category all

We will print the metrics on each category and the averaged results.

-We also provide script to test your per-category trained models (currently only support `everyday` subset of the Breaking-Bad dataset). Suppose you train the models by running `./scrips/train_everyday_categories.sh $COMMAND $CFG.py`. Then the model checkpoint will be saved in `checkpoint/$CFG-$CATEGORY-dup$X`. To collect the performance on each category, run:
+We also provide a script to gather results from models trained under multiple random seeds. Suppose you train the models per category by running `./scrips/train_everyday_categories.sh $COMMAND $CFG.py`; the model checkpoints will then be saved in `checkpoint/$CFG-$CATEGORY-dup$X`. To collect the performance on each category, run:

```
python scripts/collect_test.py --cfg_file $CFG.py --num_dup $X --ckp_suffix checkpoint/$CFG-
```

The per-category results will be formatted in LaTeX table style for ease of paper writing.
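For reference, each printed row has roughly this shape (category names and numbers are placeholders, not real results):

```
rot_rmse:
<cat_1 value> & <cat_2 value> & ... & <per-category mean>
```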

Besides, if you train the models on all categories by running `GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=$NUM_REPEAT ./scripts/dup_run_sbatch.sh $PARTITION $JOB_NAME ./scripts/train.py $CFG --other_args...`, the model checkpoints will be saved in `checkpoint/$CFG-dup$X`. To collect the performance, simply add the `--train_all` flag:

```
python scripts/collect_test.py --cfg_file $CFG.py --num_dup $X --ckp_suffix checkpoint/$CFG- --train_all
```

You can again control the number of pieces and GPUs to use.
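For example, a hypothetical multi-GPU invocation (only flags that appear in `scripts/collect_test.py` below are used here; piece-count options follow your config):

```
python scripts/collect_test.py --cfg_file $CFG.py --num_dup $X --ckp_suffix checkpoint/$CFG- --gpus 0 1
```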

## Visualization
14 changes: 9 additions & 5 deletions multi_part_assembly/models/modules/base_model.py
@@ -92,12 +92,16 @@ def test_step(self, data_dict, batch_idx):
    def test_epoch_end(self, outputs):
        # avg_loss among all data
        # we need to consider different batch_size
-       func = torch.tensor if \
-           isinstance(outputs[0]['batch_size'], int) else torch.stack
-       batch_sizes = func([output.pop('batch_size') for output in outputs
-                           ]).type_as(outputs[0]['loss'])  # [num_batches]
+       if isinstance(outputs[0]['batch_size'], int):
+           func_bs = torch.tensor
+           func_loss = torch.stack
+       else:
+           func_bs = torch.cat
+           func_loss = torch.cat
+       batch_sizes = func_bs([output.pop('batch_size') for output in outputs
+                              ]).type_as(outputs[0]['loss'])  # [num_batches]
        losses = {
-           f'test/{k}': torch.stack([output[k] for output in outputs])
+           f'test/{k}': func_loss([output[k] for output in outputs])
            for k in outputs[0].keys()
        }  # each is [num_batches], stacked avg loss in each batch
        avg_loss = {
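The `avg_loss` dict is cut off above; presumably it averages the per-batch losses weighted by batch size. A minimal sketch of that weighting, with illustrative names:

```python
import torch

def weighted_avg(losses: torch.Tensor, batch_sizes: torch.Tensor) -> torch.Tensor:
    """Average per-batch losses, weighting each batch by its number of samples.

    losses: [num_batches] stacked average loss of each batch
    batch_sizes: [num_batches] number of samples in each batch
    """
    return (losses * batch_sizes).sum() / batch_sizes.sum()
```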
3 changes: 2 additions & 1 deletion multi_part_assembly/utils/eval_utils.py
@@ -1,4 +1,5 @@
import copy
+import math

import torch

@@ -220,7 +221,7 @@ def rot_geodesic_dist(rot1, rot2, valids):
    quat2 = rot2.to_quat()
    metric_per_data = 2. * torch.acos((quat1 * quat2).sum(dim=-1).abs())
    metric_per_data = _valid_mean(metric_per_data, valids)
-   return metric_per_data
+   return metric_per_data * 180. / math.pi  # to degree


@torch.no_grad()
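A self-contained sketch of the patched metric (the `clamp` is an extra numerical safeguard, not part of the diff):

```python
import math

import torch

def quat_geodesic_deg(quat1: torch.Tensor, quat2: torch.Tensor) -> torch.Tensor:
    """Geodesic angle between unit quaternions, returned in degrees."""
    dot = (quat1 * quat2).sum(dim=-1).abs().clamp(max=1.)  # clamp keeps acos well-defined
    return 2. * torch.acos(dot) * 180. / math.pi
```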
46 changes: 40 additions & 6 deletions scripts/collect_test.py
@@ -39,8 +39,6 @@ def test(cfg):
        strategy='dp' if len(all_gpus) > 1 else None,
    )

-   # iterate over all per-category trained models
-   all_category = cfg.data.all_category
    all_metrics = {
        'rot_rmse': 1.,
        'rot_mae': 1.,
@@ -50,6 +48,34 @@
        'transform_pt_cd_loss': 1000.,
        'part_acc': 100.,
    }

+   # performance on all categories
+   if args.train_all:
+       all_results = {metric: [] for metric in all_metrics.keys()}
+       ckp_suffix = f'{args.ckp_suffix}dup'
+       _, val_loader = build_dataloader(cfg)
+       for i in range(1, args.num_dup + 1, 1):
+           ckp_folder = f'{ckp_suffix}{i}/models'
+           try:
+               ckp_path = find_last_ckp(ckp_folder)
+           except AssertionError:
+               continue
+           trainer.test(model, val_loader, ckpt_path=ckp_path)
+           results = model.test_results
+           results = {k[5:]: v.cpu().numpy() for k, v in results.items()}
+           for metric in all_metrics.keys():
+               all_results[metric].append(results[metric] *
+                                          all_metrics[metric])
+       # average over `dup` runs
+       for metric in all_metrics.keys():
+           all_results[metric] = np.mean(all_results[metric]).round(1)
+           print(f'{metric}: {all_results[metric]}')
+       # format for latex table
+       result = [str(all_results[metric]) for metric in all_metrics.keys()]
+       print(' & '.join(result))

+   # iterate over all categories
+   all_category = cfg.data.all_category
    all_results = {
        cat: {metric: []
              for metric in all_metrics.keys()}
@@ -65,7 +91,13 @@
            continue
        # iterate over all dup-trained models
        # 'dup1', 'dup2', 'dup3', ...
-       ckp_suffix = f'{args.ckp_suffix}{cat}-dup'
+       # if the model is trained on all categories together
+       # then there is only one weight
+       if args.train_all:
+           ckp_suffix = f'{args.ckp_suffix}dup'
+       # else there is one weight per category
+       else:
+           ckp_suffix = f'{args.ckp_suffix}{cat}-dup'
        for i in range(1, args.num_dup + 1, 1):
            ckp_folder = f'{ckp_suffix}{i}/models'
            try:
@@ -76,7 +108,8 @@
            results = model.test_results
            results = {k[5:]: v.cpu().numpy() for k, v in results.items()}
            for metric in all_metrics.keys():
-               all_results[cat][metric].append(results[metric])
+               all_results[cat][metric].append(results[metric] *
+                                               all_metrics[metric])
    # average over `dup` runs
    for cat in all_category:
        for metric in all_metrics.keys():
@@ -91,8 +124,7 @@
    for metric, result in all_results.items():
        print(f'{metric}:')
        result = result.tolist()
-       # per-category mean, scale it for scientific notation
-       result.append(np.nanmean(result).round(1) * all_metrics[metric])
+       result.append(np.nanmean(result).round(1))  # per-category mean
        result = [str(res) for res in result]
        print(' & '.join(result))

@@ -107,6 +139,8 @@
    parser.add_argument('--num_dup', type=int, default=3)
    parser.add_argument('--ckp_suffix', type=str, required=True)
    parser.add_argument('--gpus', nargs='+', default=[0], type=int)
+   parser.add_argument(
+       '--train_all', action='store_true', help='trained on all categories')
    args = parser.parse_args()

    sys.path.append(os.path.dirname(args.cfg_file))
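Summarizing the aggregation these hunks implement, a hedged sketch of the `--train_all` path (function name and structure are illustrative; the script inlines this logic):

```python
import numpy as np

def aggregate_runs(run_results, all_metrics):
    """Scale each metric, average over the `--num_dup` runs, then round to one decimal."""
    out = {}
    for metric, scale in all_metrics.items():
        vals = [run[metric] * scale for run in run_results]  # scale before averaging
        out[metric] = np.mean(vals).round(1)
    return out

# LaTeX-style row, as printed by the script:
# print(' & '.join(str(v) for v in aggregate_runs(runs, all_metrics).values()))
```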
5 changes: 2 additions & 3 deletions scripts/test.py
@@ -50,14 +50,13 @@ def test(cfg):
        results = model.test_results
        results = {k[5:]: v.detach().cpu().numpy() for k, v in results.items()}
        for metric in all_metrics.keys():
-           all_results[metric].append(results[metric])
+           all_results[metric].append(results[metric] * all_metrics[metric])
    all_results = {k: np.array(v).round(1) for k, v in all_results.items()}
    # format for latex table
    for metric, result in all_results.items():
        print(f'{metric}:')
        result = result.tolist()
-       # per-category mean, scale it for scientific notation
-       result.append(np.mean(result).round(1) * all_metrics[metric])
+       result.append(np.mean(result).round(1))  # per-category mean
        result = [str(res) for res in result]
        print(' & '.join(result))

