Skip to content

Commit

Permalink
Fix yolox-pose ut (#2231)
Browse files Browse the repository at this point in the history
* update yolox-pose ut

* fix lint

* fix
  • Loading branch information
huayuan4396 committed Jun 30, 2023
1 parent 3ab17f5 commit df0bf8f
Showing 1 changed file with 90 additions and 9 deletions.
99 changes: 90 additions & 9 deletions tests/test_codebase/test_mmpose/test_mmpose_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import mmengine
import pytest
import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
Expand Down Expand Up @@ -194,15 +196,18 @@ def test_scale_forward(backend_type: Backend):
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_yolox_pose_head(backend_type: Backend):
try:
from models import yolox_pose_head # noqa: F401,F403
from mmyolo.utils.setup_env import register_all_modules
from models.yolox_pose_head import YOLOXPoseHead # noqa: F401,F403
register_all_modules(True)
except ImportError:
pytest.skip(
'mmpose/projects/yolox-pose is not installed.',
allow_module_level=True)
deploy_cfg = mmengine.Config.fromfile(
'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
check_backend(backend_type, True)
model = yolox_pose_head.YOLOXPoseHead(

head = YOLOXPoseHead(
head_module=dict(
type='YOLOXPoseHeadModule',
num_classes=1,
Expand All @@ -215,19 +220,95 @@ def test_yolox_pose_head(backend_type: Backend):
use_depthwise=False,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
))
),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.IoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_pose=dict(
type='OksLoss',
metainfo='configs/_base_/datasets/coco.py',
loss_weight=30.0),
loss_bbox_aux=dict(
type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
train_cfg=ConfigDict(
assigner=dict(
type='PoseSimOTAAssigner',
center_radius=2.5,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
oks_calculator=dict(
type='OksLoss',
metainfo='configs/_base_/datasets/coco.py'))),
test_cfg=ConfigDict(
yolox_style=True,
multi_label=False,
score_thr=0.001,
max_per_img=300,
nms=dict(type='nms', iou_threshold=0.65)))

class TestYOLOXPoseHeadModel(torch.nn.Module):

def __init__(self, yolox_pose_head):
super(TestYOLOXPoseHeadModel, self).__init__()
self.yolox_pose_head = yolox_pose_head

def forward(self, x1, x2, x3):
inputs = [x1, x2, x3]
data_sample = InstanceData()
data_sample.set_metainfo(
dict(ori_shape=(640, 640), scale_factor=(1.0, 1.0)))
return self.yolox_pose_head.predict(
inputs, batch_data_samples=[data_sample])

model = TestYOLOXPoseHeadModel(head)
model.cpu().eval()

model_inputs = [
torch.randn(2, 128, 80, 80),
torch.randn(2, 128, 40, 40),
torch.randn(2, 128, 20, 20)
torch.randn(1, 128, 8, 8),
torch.randn(1, 128, 4, 4),
torch.randn(1, 128, 2, 2)
]

with torch.no_grad():
pytorch_output = model(*model_inputs)[0]
pred_bboxes = torch.from_numpy(pytorch_output.bboxes).unsqueeze(0)
pred_bboxes_scores = torch.from_numpy(pytorch_output.scores).reshape(
1, -1, 1)
pred_kpts = torch.from_numpy(pytorch_output.keypoints).unsqueeze(0)
pred_kpts_scores = torch.from_numpy(
pytorch_output.keypoint_scores).unsqueeze(0).unsqueeze(-1)

pytorch_output = [
torch.cat([pred_bboxes, pred_bboxes_scores], dim=-1),
torch.cat([pred_kpts, pred_kpts_scores], dim=-1)
]
pytorch_output = model(model_inputs)

wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'inputs': model_inputs}
rewrite_inputs = {
'x1': model_inputs[0],
'x2': model_inputs[1],
'x3': model_inputs[2]
}
deploy_cfg.onnx_config.input_names = ['x1', 'x2', 'x3']

rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
run_with_backend=False,
run_with_backend=True,
deploy_cfg=deploy_cfg)

# keep bbox coord >= 0
rewrite_outputs[0] = rewrite_outputs[0].clamp(min=0)
torch_assert_close(rewrite_outputs, pytorch_output)

0 comments on commit df0bf8f

Please sign in to comment.