From 3ab17f5367063313a6f54b1d5eecdf13939935df Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Fri, 30 Jun 2023 19:47:38 +0800
Subject: [PATCH] Fix reg test for maskrcnn (#2230)

---
 .../codebase/mmpose/models/heads/yolox_pose_head.py | 11 +++++------
 tools/regression_test.py                            |  5 +++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
index 1866dd215b..1f061cc945 100644
--- a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
+++ b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
@@ -8,8 +8,7 @@
 from mmdeploy.codebase.mmdet import get_post_processing_params
 from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.mmcv.ops.nms import multiclass_nms
-from mmdeploy.utils import Backend
-from mmdeploy.utils.config_utils import get_backend_config
+from mmdeploy.utils import Backend, get_backend
 
 
 @FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
@@ -166,9 +165,9 @@ def yolox_pose_head__predict_by_feat(
     pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
 
-    backend_config = get_backend_config(deploy_cfg)
-    if backend_config.type == Backend.TENSORRT.value:
-        # pad
+    backend = get_backend(deploy_cfg)
+    if backend == Backend.TENSORRT:
+        # pad for batched_nms because its output index is filled with -1
         bboxes = torch.cat(
             [bboxes,
              bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
@@ -188,7 +187,7 @@ def yolox_pose_head__predict_by_feat(
     iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
     score_threshold = cfg.get('score_thr', post_params.score_threshold)
     pre_top_k = post_params.get('pre_top_k', -1)
-    keep_top_k = post_params.get('keep_top_k', -1)
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
     # do nms
     _, _, nms_indices = multiclass_nms(
         bboxes,
diff --git a/tools/regression_test.py b/tools/regression_test.py
index 0cf289f146..10069f7f82 100644
--- a/tools/regression_test.py
+++ b/tools/regression_test.py
@@ -304,8 +304,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
         task_name = metafile_metric['Task']
         dataset = metafile_metric['Dataset']
 
-        # check if metafile use the same metric on several datasets
-        if len(metafile_metric_info) > 1:
+        # check if metafile use the same metric on several datasets for mmagic
+        task_info = set([_['Task'] for _ in metafile_metric_info])
+        if len(metafile_metric_info) > 1 and len(task_info) == 1:
             for k, v in metafile_metric['Metrics'].items():
                 pytorch_metric[f'{dataset} {k}'] = v
         else:
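
Note on the TensorRT-specific padding in the mmpose hunk: the sketch below is a minimal, standalone illustration (plain PyTorch, no mmdeploy required) of why a zero row is appended before NMS when the backend pads unused output slots with index -1. The tensor shapes and `fake_nms_indices` are made up for the demo; the real indices come from mmdeploy's `multiclass_nms` on TensorRT.

```python
import torch

# In PyTorch, index -1 selects the *last* element along a dimension, so
# appending an all-zero "dummy" box makes the -1 padding slots gather zeros
# instead of a real detection.
batch, num_boxes = 2, 5
bboxes = torch.rand(batch, num_boxes, 4)   # (N, num_boxes, 4)
scores = torch.rand(batch, num_boxes, 1)   # (N, num_boxes, 1)

# Pad with one all-zero box per image, mirroring the patch.
bboxes = torch.cat([bboxes, bboxes.new_zeros((batch, 1, 4))], dim=1)
scores = torch.cat([scores, scores.new_zeros((batch, 1, 1))], dim=1)

# Pretend NMS kept 2 boxes for image 0 and 1 box for image 1; remaining
# slots are filled with -1, as TensorRT's batched NMS does.
fake_nms_indices = torch.tensor([[0, 3, -1], [2, -1, -1]])

batch_inds = torch.arange(batch).view(-1, 1)
kept_bboxes = bboxes[batch_inds, fake_nms_indices]  # -1 -> zero row, (2, 3, 4)
kept_scores = scores[batch_inds, fake_nms_indices]  # (2, 3, 1)

print(kept_bboxes[0, -1])  # tensor([0., 0., 0., 0.])
```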