Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update yolox-pose ut #2231

Merged
merged 3 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def yolox_pose_head__predict_by_feat(
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.get('pre_top_k', -1)
keep_top_k = post_params.get('keep_top_k', -1)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
# do nms
_, _, nms_indices = multiclass_nms(
bboxes,
Expand Down
99 changes: 90 additions & 9 deletions tests/test_codebase/test_mmpose/test_mmpose_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import mmengine
import pytest
import torch
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData

from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase
Expand Down Expand Up @@ -194,15 +196,18 @@ def test_scale_forward(backend_type: Backend):
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_yolox_pose_head(backend_type: Backend):
try:
from models import yolox_pose_head # noqa: F401,F403
from mmyolo.utils.setup_env import register_all_modules
from models.yolox_pose_head import YOLOXPoseHead # noqa: F401,F403
register_all_modules(True)
except ImportError:
pytest.skip(
'mmpose/projects/yolox-pose is not installed.',
allow_module_level=True)
deploy_cfg = mmengine.Config.fromfile(
'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
check_backend(backend_type, True)
model = yolox_pose_head.YOLOXPoseHead(

head = YOLOXPoseHead(
head_module=dict(
type='YOLOXPoseHeadModule',
num_classes=1,
Expand All @@ -215,19 +220,95 @@ def test_yolox_pose_head(backend_type: Backend):
use_depthwise=False,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
))
),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_bbox=dict(
type='mmdet.IoULoss',
mode='square',
eps=1e-16,
reduction='sum',
loss_weight=5.0),
loss_obj=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
loss_pose=dict(
type='OksLoss',
metainfo='configs/_base_/datasets/coco.py',
loss_weight=30.0),
loss_bbox_aux=dict(
type='mmdet.L1Loss', reduction='sum', loss_weight=1.0),
train_cfg=ConfigDict(
assigner=dict(
type='PoseSimOTAAssigner',
center_radius=2.5,
iou_calculator=dict(type='mmdet.BboxOverlaps2D'),
oks_calculator=dict(
type='OksLoss',
metainfo='configs/_base_/datasets/coco.py'))),
test_cfg=ConfigDict(
yolox_style=True,
multi_label=False,
score_thr=0.001,
max_per_img=300,
nms=dict(type='nms', iou_threshold=0.65)))

class TestYOLOXPoseHeadModel(torch.nn.Module):

def __init__(self, yolox_pose_head):
super(TestYOLOXPoseHeadModel, self).__init__()
self.yolox_pose_head = yolox_pose_head

def forward(self, x1, x2, x3):
inputs = [x1, x2, x3]
data_sample = InstanceData()
data_sample.set_metainfo(
dict(ori_shape=(640, 640), scale_factor=(1.0, 1.0)))
return self.yolox_pose_head.predict(
inputs, batch_data_samples=[data_sample])

model = TestYOLOXPoseHeadModel(head)
model.cpu().eval()

model_inputs = [
torch.randn(2, 128, 80, 80),
torch.randn(2, 128, 40, 40),
torch.randn(2, 128, 20, 20)
torch.randn(1, 128, 8, 8),
torch.randn(1, 128, 4, 4),
torch.randn(1, 128, 2, 2)
]

with torch.no_grad():
pytorch_output = model(*model_inputs)[0]
pred_bboxes = torch.from_numpy(pytorch_output.bboxes).unsqueeze(0)
pred_bboxes_scores = torch.from_numpy(pytorch_output.scores).reshape(
1, -1, 1)
pred_kpts = torch.from_numpy(pytorch_output.keypoints).unsqueeze(0)
pred_kpts_scores = torch.from_numpy(
pytorch_output.keypoint_scores).unsqueeze(0).unsqueeze(-1)

pytorch_output = [
torch.cat([pred_bboxes, pred_bboxes_scores], dim=-1),
torch.cat([pred_kpts, pred_kpts_scores], dim=-1)
]
pytorch_output = model(model_inputs)

wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'inputs': model_inputs}
rewrite_inputs = {
'x1': model_inputs[0],
'x2': model_inputs[1],
'x3': model_inputs[2]
}
deploy_cfg.onnx_config.input_names = ['x1', 'x2', 'x3']

rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
run_with_backend=False,
run_with_backend=True,
deploy_cfg=deploy_cfg)

# keep bbox coord >= 0
rewrite_outputs[0] = rewrite_outputs[0].clamp(min=0)
torch_assert_close(rewrite_outputs, pytorch_output)