From df0bf8fe77fc6b41ec0a9cffb6e6d74412626d89 Mon Sep 17 00:00:00 2001 From: huayuan4396 <110151316+huayuan4396@users.noreply.github.com> Date: Fri, 30 Jun 2023 19:59:01 +0800 Subject: [PATCH] Fix yolox-pose ut (#2231) * update yolox-pose ut * fix lint * fix --- .../test_mmpose/test_mmpose_models.py | 99 +++++++++++++++++-- 1 file changed, 90 insertions(+), 9 deletions(-) diff --git a/tests/test_codebase/test_mmpose/test_mmpose_models.py b/tests/test_codebase/test_mmpose/test_mmpose_models.py index 93721e129f..a83332a68e 100644 --- a/tests/test_codebase/test_mmpose/test_mmpose_models.py +++ b/tests/test_codebase/test_mmpose/test_mmpose_models.py @@ -2,6 +2,8 @@ import mmengine import pytest import torch +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase @@ -194,7 +196,9 @@ def test_scale_forward(backend_type: Backend): @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) def test_yolox_pose_head(backend_type: Backend): try: - from models import yolox_pose_head # noqa: F401,F403 + from mmyolo.utils.setup_env import register_all_modules + from models.yolox_pose_head import YOLOXPoseHead # noqa: F401,F403 + register_all_modules(True) except ImportError: pytest.skip( 'mmpose/projects/yolox-pose is not installed.', @@ -202,7 +206,8 @@ def test_yolox_pose_head(backend_type: Backend): deploy_cfg = mmengine.Config.fromfile( 'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py') check_backend(backend_type, True) - model = yolox_pose_head.YOLOXPoseHead( + + head = YOLOXPoseHead( head_module=dict( type='YOLOXPoseHeadModule', num_classes=1, @@ -215,19 +220,95 @@ def test_yolox_pose_head(backend_type: Backend): use_depthwise=False, norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), act_cfg=dict(type='SiLU', inplace=True), - )) + ), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.IoULoss', + mode='square', + eps=1e-16, + reduction='sum', + loss_weight=5.0), + loss_obj=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + loss_pose=dict( + type='OksLoss', + metainfo='configs/_base_/datasets/coco.py', + loss_weight=30.0), + loss_bbox_aux=dict( + type='mmdet.L1Loss', reduction='sum', loss_weight=1.0), + train_cfg=ConfigDict( + assigner=dict( + type='PoseSimOTAAssigner', + center_radius=2.5, + iou_calculator=dict(type='mmdet.BboxOverlaps2D'), + oks_calculator=dict( + type='OksLoss', + metainfo='configs/_base_/datasets/coco.py'))), + test_cfg=ConfigDict( + yolox_style=True, + multi_label=False, + score_thr=0.001, + max_per_img=300, + nms=dict(type='nms', iou_threshold=0.65))) + + class TestYOLOXPoseHeadModel(torch.nn.Module): + + def __init__(self, yolox_pose_head): + super(TestYOLOXPoseHeadModel, self).__init__() + self.yolox_pose_head = yolox_pose_head + + def forward(self, x1, x2, x3): + inputs = [x1, x2, x3] + data_sample = InstanceData() + data_sample.set_metainfo( + dict(ori_shape=(640, 640), scale_factor=(1.0, 1.0))) + return self.yolox_pose_head.predict( + inputs, batch_data_samples=[data_sample]) + + model = TestYOLOXPoseHeadModel(head) model.cpu().eval() + model_inputs = [ - torch.randn(2, 128, 80, 80), - torch.randn(2, 128, 40, 40), - torch.randn(2, 128, 20, 20) + torch.randn(1, 128, 8, 8), + torch.randn(1, 128, 4, 4), + torch.randn(1, 128, 2, 2) + ] + + with torch.no_grad(): + pytorch_output = model(*model_inputs)[0] + pred_bboxes = torch.from_numpy(pytorch_output.bboxes).unsqueeze(0) + pred_bboxes_scores = torch.from_numpy(pytorch_output.scores).reshape( + 1, -1, 1) + pred_kpts = torch.from_numpy(pytorch_output.keypoints).unsqueeze(0) + pred_kpts_scores = torch.from_numpy( + pytorch_output.keypoint_scores).unsqueeze(0).unsqueeze(-1) + + pytorch_output = [ + torch.cat([pred_bboxes, pred_bboxes_scores], dim=-1), + torch.cat([pred_kpts, pred_kpts_scores], dim=-1) ] - pytorch_output = model(model_inputs) + wrapped_model = WrapModel(model, 'forward') - rewrite_inputs = {'inputs': model_inputs} + rewrite_inputs = { + 'x1': model_inputs[0], + 'x2': model_inputs[1], + 'x3': model_inputs[2] + } + deploy_cfg.onnx_config.input_names = ['x1', 'x2', 'x3'] + rewrite_outputs, _ = get_rewrite_outputs( wrapped_model=wrapped_model, model_inputs=rewrite_inputs, - run_with_backend=False, + run_with_backend=True, deploy_cfg=deploy_cfg) + + # keep bbox coord >= 0 + rewrite_outputs[0] = rewrite_outputs[0].clamp(min=0) torch_assert_close(rewrite_outputs, pytorch_output)