From 8dc1fd34f0befd2bf68b95475b802e53214aa797 Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Thu, 4 Aug 2022 10:55:19 -0400 Subject: [PATCH] fix testing functions --- configs/_base_/datasets/breaking_bad/everyday.py | 5 +++++ scripts/collect_test.py | 15 +++++---------- scripts/sbatch_run.sh | 2 +- scripts/test.py | 14 +++++--------- scripts/train_everyday_categories.sh | 3 ++- scripts/train_one_category.sh | 3 ++- 6 files changed, 20 insertions(+), 22 deletions(-) diff --git a/configs/_base_/datasets/breaking_bad/everyday.py b/configs/_base_/datasets/breaking_bad/everyday.py index 5747e92..38a051f 100644 --- a/configs/_base_/datasets/breaking_bad/everyday.py +++ b/configs/_base_/datasets/breaking_bad/everyday.py @@ -14,6 +14,11 @@ _C.max_num_part = 20 _C.shuffle_parts = False _C.overfit = -1 +_C.all_category = [ + 'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate', 'Spoon', + 'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie', 'DrinkBottle', + 'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot', 'Vase', 'WineGlass' +] _C.colors = [ [0, 204, 0], [204, 0, 0], diff --git a/scripts/collect_test.py b/scripts/collect_test.py index 7eb2489..97b474f 100644 --- a/scripts/collect_test.py +++ b/scripts/collect_test.py @@ -40,16 +40,11 @@ def test(cfg): ) # iterate over all per-category trained models - # TODO: currently we hard-code to support Breaking Bad dataset - all_category = [ - 'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate', - 'Spoon', 'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie', - 'DrinkBottle', 'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot', - 'Vase', 'WineGlass' - ] + all_category = cfg.data.all_category all_metrics = { 'rot_rmse': 1., 'rot_mae': 1., + 'geo_rot': 1., 'trans_rmse': 100., 'trans_mae': 100., 'transform_pt_cd_loss': 1000., @@ -81,8 +76,7 @@ def test(cfg): results = model.test_results results = {k[5:]: v.cpu().numpy() for k, v in results.items()} for metric in all_metrics.keys(): - all_results[cat][metric].append(results[metric] * - all_metrics[metric]) + all_results[cat][metric].append(results[metric]) # average over `dup` runs for cat in all_category: for metric in all_metrics.keys(): @@ -97,7 +91,8 @@ def test(cfg): for metric, result in all_results.items(): print(f'{metric}:') result = result.tolist() - result.append(np.nanmean(result).round(1)) # per-category mean + # per-category mean, scale it for scientific notation + result.append(np.nanmean(result).round(1) * all_metrics[metric]) result = [str(res) for res in result] print(' & '.join(result)) diff --git a/scripts/sbatch_run.sh b/scripts/sbatch_run.sh index 34bada4..d3dc9de 100755 --- a/scripts/sbatch_run.sh +++ b/scripts/sbatch_run.sh @@ -68,5 +68,5 @@ python $PY_FILE $PY_ARGS >> $LOG_FILE # the script above, with it sbatch run-${SLRM_NAME}.slrm # delete it -sleep 1 +sleep 0.1 rm -f run-${SLRM_NAME}.slrm diff --git a/scripts/test.py b/scripts/test.py index 13d65ba..f0148ea 100644 --- a/scripts/test.py +++ b/scripts/test.py @@ -31,16 +31,11 @@ def test(cfg): return # if `args.category` is 'all', we also compute per-category results - # TODO: currently we hard-code to support Breaking Bad dataset - all_category = [ - 'BeerBottle', 'Bowl', 'Cup', 'DrinkingUtensil', 'Mug', 'Plate', - 'Spoon', 'Teacup', 'ToyFigure', 'WineBottle', 'Bottle', 'Cookie', - 'DrinkBottle', 'Mirror', 'PillBottle', 'Ring', 'Statue', 'Teapot', - 'Vase', 'WineGlass' - ] + all_category = cfg.data.all_category all_metrics = { 'rot_rmse': 1., 'rot_mae': 1., + 'geo_rot': 1., 'trans_rmse': 100., 'trans_mae': 100., 'transform_pt_cd_loss': 1000., @@ -55,13 +50,14 @@ def test(cfg): results = model.test_results results = {k[5:]: v.detach().cpu().numpy() for k, v in results.items()} for metric in all_metrics.keys(): - all_results[metric].append(results[metric] * all_metrics[metric]) + all_results[metric].append(results[metric]) all_results = {k: np.array(v).round(1) for k, v in all_results.items()} # format for latex table for metric, result in all_results.items(): print(f'{metric}:') result = result.tolist() - result.append(np.mean(result).round(1)) # per-category mean + # per-category mean, scale it for scientific notation + result.append(np.mean(result).round(1) * all_metrics[metric]) result = [str(res) for res in result] print(' & '.join(result)) diff --git a/scripts/train_everyday_categories.sh b/scripts/train_everyday_categories.sh index c634701..772d679 100755 --- a/scripts/train_everyday_categories.sh +++ b/scripts/train_everyday_categories.sh @@ -4,7 +4,7 @@ ####################################################################### # An example usage: -# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 everyday_cat ./scripts/train.py config.py --fp16 --cudnn" config.py +# ./scripts/train_everyday_categories.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py ####################################################################### CMD=$1 @@ -15,6 +15,7 @@ do cfg="${CFG:0:(-3)}-$cat.py" cp $CFG $cfg cmd="${CMD/$CFG/$cfg}" + cmd="${cmd/CATEGORY/$cat}" cmd="$cmd --category $cat" eval $cmd done diff --git a/scripts/train_one_category.sh b/scripts/train_one_category.sh index 28aeb1a..9c68f26 100755 --- a/scripts/train_one_category.sh +++ b/scripts/train_one_category.sh @@ -4,7 +4,7 @@ ####################################################################### # An example usage: -# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 everyday_cat ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle +# ./scripts/train_one_category.sh "GPUS=1 CPUS_PER_TASK=8 MEM_PER_CPU=5 QOS=normal REPEAT=3 ./scripts/dup_run_sbatch.sh rtx6000 global-everyday-xxx-CATEGORY ./scripts/train.py config.py --fp16 --cudnn" config.py Bottle ####################################################################### CMD=$1 @@ -14,5 +14,6 @@ cat=$3 cfg="${CFG:0:(-3)}-$cat.py" cp $CFG $cfg cmd="${CMD/$CFG/$cfg}" +cmd="${cmd/CATEGORY/$cat}" cmd="$cmd --category $cat" eval $cmd