Skip to content

Commit

Permalink
Fix reg test for maskrcnn (#2230)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Jun 30, 2023
1 parent e19f6fa commit 3ab17f5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
11 changes: 5 additions & 6 deletions mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms
from mmdeploy.utils import Backend
from mmdeploy.utils.config_utils import get_backend_config
from mmdeploy.utils import Backend, get_backend


@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
Expand Down Expand Up @@ -166,9 +165,9 @@ def yolox_pose_head__predict_by_feat(

pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)

backend_config = get_backend_config(deploy_cfg)
if backend_config.type == Backend.TENSORRT.value:
# pad
backend = get_backend(deploy_cfg)
if backend == Backend.TENSORRT:
# pad for batched_nms because its output index is filled with -1
bboxes = torch.cat(
[bboxes,
bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
Expand All @@ -188,7 +187,7 @@ def yolox_pose_head__predict_by_feat(
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.get('pre_top_k', -1)
keep_top_k = post_params.get('keep_top_k', -1)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
# do nms
_, _, nms_indices = multiclass_nms(
bboxes,
Expand Down
5 changes: 3 additions & 2 deletions tools/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,9 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
task_name = metafile_metric['Task']
dataset = metafile_metric['Dataset']

# check if metafile use the same metric on several datasets
if len(metafile_metric_info) > 1:
# check if metafile use the same metric on several datasets for mmagic
task_info = set([_['Task'] for _ in metafile_metric_info])
if len(metafile_metric_info) > 1 and len(task_info) == 1:
for k, v in metafile_metric['Metrics'].items():
pytorch_metric[f'{dataset} {k}'] = v
else:
Expand Down

0 comments on commit 3ab17f5

Please sign in to comment.