Skip to content

Commit

Permalink
fix testing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuziyi616 committed Aug 4, 2022
1 parent fa0fcf5 commit 8dc1fd3
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 22 deletions.
5 changes: 5 additions & 0 deletions configs/_base_/datasets/breaking_bad/everyday.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
_C.max_num_part = 20
_C.shuffle_parts = False
_C.overfit = -1
_C.all_category = [
'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate', 'Spoon',
'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie', 'DrinkBottle',
'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot', 'Vase', 'WineGlass'
]
_C.colors = [
[0, 204, 0],
[204, 0, 0],
Expand Down
15 changes: 5 additions & 10 deletions scripts/collect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,11 @@ def test(cfg):
)

# iterate over all per-category trained models
# TODO: currently we hard-code to support Breaking Bad dataset
all_category = [
'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate',
'Spoon', 'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie',
'DrinkBottle', 'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot',
'Vase', 'WineGlass'
]
all_category = cfg.data.all_category
all_metrics = {
'rot_rmse': 1.,
'rot_mae': 1.,
'geo_rot': 1.,
'trans_rmse': 100.,
'trans_mae': 100.,
'transform_pt_cd_loss': 1000.,
Expand Down Expand Up @@ -81,8 +76,7 @@ def test(cfg):
results = model.test_results
results = {k[5:]: v.cpu().numpy() for k, v in results.items()}
for metric in all_metrics.keys():
all_results[cat][metric].append(results[metric] *
all_metrics[metric])
all_results[cat][metric].append(results[metric])
# average over `dup` runs
for cat in all_category:
for metric in all_metrics.keys():
Expand All @@ -97,7 +91,8 @@ def test(cfg):
for metric, result in all_results.items():
print(f'{metric}:')
result = result.tolist()
result.append(np.nanmean(result).round(1)) # per-category mean
# per-category mean, scale it for scientific notation
result.append(np.nanmean(result).round(1) * all_metrics[metric])
result = [str(res) for res in result]
print(' & '.join(result))

Expand Down
2 changes: 1 addition & 1 deletion scripts/sbatch_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,5 @@ python $PY_FILE $PY_ARGS >> $LOG_FILE # the script above, with it
sbatch run-${SLRM_NAME}.slrm

# delete it
sleep 1
sleep 0.1
rm -f run-${SLRM_NAME}.slrm
14 changes: 5 additions & 9 deletions scripts/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,11 @@ def test(cfg):
return

# if `args.category` is 'all', we also compute per-category results
# TODO: currently we hard-code to support Breaking Bad dataset
all_category = [
'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate',
'Spoon', 'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie',
'DrinkBottle', 'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot',
'Vase', 'WineGlass'
]
all_category = cfg.data.all_category
all_metrics = {
'rot_rmse': 1.,
'rot_mae': 1.,
'geo_rot': 1.,
'trans_rmse': 100.,
'trans_mae': 100.,
'transform_pt_cd_loss': 1000.,
Expand All @@ -55,13 +50,14 @@ def test(cfg):
results = model.test_results
results = {k[5:]: v.detach().cpu().numpy() for k, v in results.items()}
for metric in all_metrics.keys():
all_results[metric].append(results[metric] * all_metrics[metric])
all_results[metric].append(results[metric])
all_results = {k: np.array(v).round(1) for k, v in all_results.items()}
# format for latex table
for metric, result in all_results.items():
print(f'{metric}:')
result = result.tolist()
result.append(np.mean(result).round(1)) # per-category mean
# per-category mean, scale it for scientific notation
result.append(np.mean(result).round(1) * all_metrics[metric])
result = [str(res) for res in result]
print(' & '.join(result))

Expand Down
3 changes: 2 additions & 1 deletion scripts/train_everyday_categories.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#######################################################################
# An example usage:
# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 everyday_cat ./scripts/train.py config.py --fp16 --cudnn" config.py
# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py
#######################################################################

CMD=$1
Expand All @@ -15,6 +15,7 @@ do
cfg="${CFG:0:(-3)}-$cat.py"
cp $CFG $cfg
cmd="${CMD/$CFG/$cfg}"
cmd="${cmd/CATEGORY/$cat}"
cmd="$cmd --category $cat"
eval $cmd
done
3 changes: 2 additions & 1 deletion scripts/train_one_category.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#######################################################################
# An example usage:
# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 everyday_cat ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle
# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle
#######################################################################

CMD=$1
Expand All @@ -14,5 +14,6 @@ cat=$3
cfg="${CFG:0:(-3)}-$cat.py"
cp $CFG $cfg
cmd="${CMD/$CFG/$cfg}"
cmd="${cmd/CATEGORY/$cat}"
cmd="$cmd --category $cat"
eval $cmd

0 comments on commit 8dc1fd3

Please sign in to comment.