Skip to content

Commit

Permalink
Support deploy of YoloX-Pose (#2184)
Browse files Browse the repository at this point in the history
* dev_mmpose

* tide

* fix lint

* del redundant task and model

* fix

* test ut

* test ut

* upload configs

* fix

* remove debug

* fix lint

* use mmcv.ops.nms

* fix lint

* remove loop

* debug

* test modified ut

* fix lint

* fix return type

* fix

* fix rescale

* fix

* fix pack_result

* update batch inference

* fix nms and pytorch show_box

* fix lint

* modify ut

* add docstring

* modify nms

* fix

* add openvino config

* update docs

* fix test_mmpose

---------

Co-authored-by: RunningLeon <[email protected]>
  • Loading branch information
huayuan4396 and RunningLeon committed Jun 28, 2023
1 parent a664f06 commit e19f6fa
Show file tree
Hide file tree
Showing 15 changed files with 455 additions and 21 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ jobs:
run: |
git clone -b dev --depth 1 https://github.com/open-mmlab/mmyolo.git /home/runner/work/mmyolo
python -m pip install -v -e /home/runner/work/mmyolo
- name: Install mmpose
run: |
git clone --depth 1 https://github.com/open-mmlab/mmpose.git /home/runner/work/mmpose
python -m pip install -v -e /home/runner/work/mmpose
- name: Build and install
run: |
rm -rf .eggs && python -m pip install -e .
Expand Down
25 changes: 25 additions & 0 deletions configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# YOLOX-Pose deployment config for the ONNX Runtime backend: inherits the
# static pose-detection settings and the ONNX Runtime backend settings.
_base_ = [
    './pose-detection_static.py',
    '../_base_/backends/onnxruntime.py',
]

# The exported model has two outputs. Axis 0 of the input and of both
# outputs is declared dynamic under the name 'batch'.
onnx_config = {
    'output_names': ['dets', 'keypoints'],
    'dynamic_axes': {
        'input': {0: 'batch'},
        'dets': {0: 'batch'},
        'keypoints': {0: 'batch'},
    },
}

# Post-processing settings (score/IoU thresholds and box-count limits).
codebase_config = {
    'post_processing': {
        'score_threshold': 0.05,
        'iou_threshold': 0.5,
        'max_output_boxes_per_class': 200,
        'pre_top_k': 5000,
        'keep_top_k': 100,
        'background_label_id': -1,
    },
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# YOLOX-Pose deployment config built on the OpenVINO backend settings and
# the static pose-detection settings.
_base_ = [
    './pose-detection_static.py',
    '../_base_/backends/openvino.py',
]

# The exported model has two outputs. Axis 0 of the input and of both
# outputs is declared dynamic under the name 'batch'.
onnx_config = {
    'output_names': ['dets', 'keypoints'],
    'dynamic_axes': {
        'input': {0: 'batch'},
        'dets': {0: 'batch'},
        'keypoints': {0: 'batch'},
    },
}

# Single model input described by its opt shape [1, 3, 640, 640].
backend_config = {
    'model_inputs': [
        {'opt_shapes': {'input': [1, 3, 640, 640]}},
    ],
}

# Post-processing settings (score/IoU thresholds and box-count limits).
codebase_config = {
    'post_processing': {
        'score_threshold': 0.05,
        'iou_threshold': 0.5,
        'max_output_boxes_per_class': 200,
        'pre_top_k': 5000,
        'keep_top_k': 100,
        'background_label_id': -1,
    },
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# YOLOX-Pose deployment config built on the TensorRT backend settings and
# the static pose-detection settings.
_base_ = [
    './pose-detection_static.py',
    '../_base_/backends/tensorrt.py',
]

# The exported model has two outputs. Axis 0 of the input and of both
# outputs is declared dynamic under the name 'batch'.
onnx_config = {
    'output_names': ['dets', 'keypoints'],
    'dynamic_axes': {
        'input': {0: 'batch'},
        'dets': {0: 'batch'},
        'keypoints': {0: 'batch'},
    },
}

# min/opt/max shapes are all [1, 3, 640, 640], so the engine profile
# covers exactly one input shape.
backend_config = {
    'common_config': {'max_workspace_size': 1 << 30},
    'model_inputs': [
        {
            'input_shapes': {
                'input': {
                    'min_shape': [1, 3, 640, 640],
                    'opt_shape': [1, 3, 640, 640],
                    'max_shape': [1, 3, 640, 640],
                },
            },
        },
    ],
}

# Post-processing settings (score/IoU thresholds and box-count limits).
codebase_config = {
    'post_processing': {
        'score_threshold': 0.05,
        'iou_threshold': 0.5,
        'max_output_boxes_per_class': 200,
        'pre_top_k': 5000,
        'keep_top_k': 100,
        'background_label_id': -1,
    },
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __launch_bounds__(nthds_per_cta) __global__
bboxOffset) *
5;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 5;
nmsedIndex[i] = bboxId / 5 - bboxOffset;
}
// clipped bbox xmin
nmsedDets[i * 6] =
Expand All @@ -74,7 +74,7 @@ __launch_bounds__(nthds_per_cta) __global__
bboxOffset) *
4;
if (nmsedIndex != nullptr) {
nmsedIndex[i] = bboxId / 4;
nmsedIndex[i] = bboxId / 4 - bboxOffset;
}
// clipped bbox xmin
nmsedDets[i * 5] =
Expand Down
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ TODO
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,4 @@ task_processor.visualize(
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
25 changes: 21 additions & 4 deletions mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ class MMPose(MMCodebase):
@classmethod
def register_deploy_modules(cls):
    """Register the rewritings used when deploying this codebase.

    Nothing is returned and none of the imported names are used; the
    packages are imported so that whatever they register at import time
    becomes active. Besides the mmpose models package, the mmdet
    ``models``, ``ops`` and ``structures`` packages are imported as well.
    """
    # Order is kept exactly as written: registration happens as an
    # import-time side effect, so the statements must not be reshuffled.
    import mmdeploy.codebase.mmdet.models
    import mmdeploy.codebase.mmdet.ops
    import mmdeploy.codebase.mmdet.structures
    import mmdeploy.codebase.mmpose.models  # noqa: F401

@classmethod
Expand Down Expand Up @@ -202,9 +205,11 @@ def create_input(self,
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
img_path = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
img_path = imgs
img_data = [mmcv.imread(img) for img in imgs]
imgs = img_data
person_results = []
Expand All @@ -220,7 +225,7 @@ def create_input(self,
TRANSFORMS.build(c) for c in cfg.test_dataloader.dataset.pipeline
]
test_pipeline = Compose(test_pipeline)
if input_shape is not None:
if input_shape is not None and hasattr(cfg, 'codec'):
if isinstance(cfg.codec, dict):
codec = cfg.codec
elif isinstance(cfg.codec, list):
Expand All @@ -243,9 +248,15 @@ def create_input(self,
bbox_score = np.array([bbox[4] if len(bbox) == 5 else 1
]) # shape (1,)
data = {
'img': imgs[i],
'bbox_score': bbox_score,
'bbox': bbox[None], # shape (1, 4)
'img':
imgs[i],
'bbox_score':
bbox_score,
'bbox': [] if hasattr(cfg.model, 'bbox_head')
and cfg.model.bbox_head.type == 'YOLOXPoseHead' else
bbox[None],
'img_path':
img_path[i]
}
data.update(meta_data)
data = test_pipeline(data)
Expand Down Expand Up @@ -288,11 +299,17 @@ def visualize(self,

if isinstance(image, str):
image = mmcv.imread(image, channel_order='rgb')
draw_bbox = result.pred_instances.bboxes is not None
if draw_bbox and isinstance(result.pred_instances.bboxes,
torch.Tensor):
result.pred_instances.bboxes = result.pred_instances.bboxes.cpu(
).numpy()
visualizer.add_datasample(
name,
image,
data_sample=result,
draw_gt=False,
draw_bbox=draw_bbox,
show=show_result,
out_file=output_file)

Expand Down
57 changes: 54 additions & 3 deletions mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def __init__(self,
device=device,
**kwargs)
# create head for decoding heatmap
self.head = builder.build_head(model_cfg.model.head)
self.head = builder.build_head(model_cfg.model.head) if hasattr(
model_cfg.model, 'head') else None

def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
device: str, **kwargs):
Expand Down Expand Up @@ -97,6 +98,9 @@ def forward(self,
inputs = inputs.contiguous().to(self.device)
batch_outputs = self.wrapper({self.input_name: inputs})
batch_outputs = self.wrapper.output_to_list(batch_outputs)
if self.model_cfg.model.type == 'YOLODetector':
return self.pack_yolox_pose_result(batch_outputs, data_samples)

codec = self.model_cfg.codec
if isinstance(codec, (list, tuple)):
codec = codec[-1]
Expand Down Expand Up @@ -158,6 +162,48 @@ def pack_result(self,

return data_samples

def pack_yolox_pose_result(self, preds: List[torch.Tensor],
                           data_samples: List[BaseDataElement]):
    """Pack yolox-pose prediction results to mmpose format.

    Args:
        preds (List[Tensor]): Prediction of bboxes and key-points, given
            as the pair ``(dets, keypoints)``. ``dets`` is indexed as
            ``[batch, instance, 5]`` (four box coordinates followed by
            the box score) and ``keypoints`` as
            ``[batch, instance, keypoint, 3]`` (two coordinates followed
            by the keypoint score).
        data_samples (List[BaseDataElement]): A list of meta info for
            image(s). Every sample must expose ``scale_factor``.

    Returns:
        data_samples (List[BaseDataElement]):
            updated data_samples with predictions.
    """
    assert preds[0].shape[0] == len(data_samples)
    batched_dets, batched_kpts = preds
    for data_sample_idx, data_sample in enumerate(data_samples):
        bboxes = batched_dets[data_sample_idx, :, :4]
        bbox_scores = batched_dets[data_sample_idx, :, 4]
        keypoints = batched_kpts[data_sample_idx, :, :, :2]
        keypoint_scores = batched_kpts[data_sample_idx, :, :, 2]

        # filter zero or negative scores; boolean-mask indexing returns
        # copies, so the in-place rescale below never touches ``preds``
        inds = bbox_scores > 0.0
        bboxes = bboxes[inds, :]
        bbox_scores = bbox_scores[inds]
        keypoints = keypoints[inds, :]
        keypoint_scores = keypoint_scores[inds]

        # rescale: convert the scale factor once, matching the dtype and
        # device of the predictions (re-wrapping a tensor in new_tensor
        # only makes extra copies and triggers a PyTorch warning)
        scale_factor = keypoints.new_tensor(data_sample.scale_factor)
        keypoints /= scale_factor.reshape(1, 1, 2)
        # the two factors are tiled to cover all four box coordinates
        bboxes /= scale_factor.repeat(1, 2)

        pred_instances = InstanceData()
        pred_instances.bboxes = bboxes.cpu().numpy()
        pred_instances.bbox_scores = bbox_scores
        # the precision test requires keypoints to be np.ndarray
        pred_instances.keypoints = keypoints.cpu().numpy()
        pred_instances.keypoint_scores = keypoint_scores
        # every kept instance gets label 0 (field was misspelled
        # ``lebels`` before, so the labels were never visible as such)
        pred_instances.labels = torch.zeros(bboxes.shape[0])

        data_sample.pred_instances = pred_instances
    return data_samples


@__BACKEND_MODEL.register_module('sdk')
class SDKEnd2EndModel(End2EndModel):
Expand Down Expand Up @@ -236,8 +282,13 @@ def build_pose_detection_model(
if isinstance(data_preprocessor, dict):
dp = data_preprocessor.copy()
dp_type = dp.pop('type')
assert dp_type == 'PoseDataPreprocessor'
data_preprocessor = PoseDataPreprocessor(**dp)
if dp_type == 'mmdet.DetDataPreprocessor':
from mmdet.models.data_preprocessors import DetDataPreprocessor
data_preprocessor = DetDataPreprocessor(**dp)
else:
assert dp_type == 'PoseDataPreprocessor'
data_preprocessor = PoseDataPreprocessor(**dp)

backend_pose_model = __BACKEND_MODEL.build(
dict(
type=model_type,
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmpose/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mspn_head
from . import mspn_head, yolox_pose_head # noqa: F401,F403

__all__ = ['mspn_head']
__all__ = ['mspn_head', 'yolox_pose_head']
Loading

0 comments on commit e19f6fa

Please sign in to comment.