From 1e3d06def7f40ae2f29526e00a852b7d807bf6b4 Mon Sep 17 00:00:00 2001
From: Peng Lu
Date: Thu, 14 Dec 2023 16:01:38 +0800
Subject: [PATCH] [Feature] Support ONNX and TensorRT exportation of RTMO
 models (#2597)

* support ONNX&TensorRT exportation of RTMO

* add configs for rtmo

* replace bbox expansion factor with parameter bbox_padding

* refine code

* refine comment

* apply model.switch_to_deploy in BaseTask.build_pytorch_model

* fix lint

* add rtmo into regression test

* add rtmo with trt backend into regression test

* add rtmo into supported model list
---
 ...pose-detection_rtmo_onnxruntime_dynamic.py |  25 +++++
 ...tion_rtmo_tensorrt-fp16_dynamic-640x640.py |  36 +++++++
 docs/en/04-supported-codebases/mmpose.md      |   1 +
 docs/zh_cn/04-supported-codebases/mmpose.md   |   1 +
 mmdeploy/codebase/base/task.py                |   5 +
 .../codebase/mmpose/models/heads/__init__.py  |   4 +-
 .../codebase/mmpose/models/heads/rtmo_head.py | 100 ++++++++++++++++++
 tests/regression/mmpose.yml                   |  14 +++
 8 files changed, 184 insertions(+), 2 deletions(-)
 create mode 100644 configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
 create mode 100644 configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
 create mode 100644 mmdeploy/codebase/mmpose/models/heads/rtmo_head.py

diff --git a/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py b/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
new file mode 100644
index 0000000000..c1fbdaaeb0
--- /dev/null
+++ b/configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
@@ -0,0 +1,25 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=2000,
+        keep_top_k=50,
+        background_label_id=-1,
+    ))
diff --git a/configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py b/configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
new file mode 100644
index 0000000000..cedc8f7097
--- /dev/null
+++ b/configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
@@ -0,0 +1,36 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt-fp16.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 640, 640],
+                    opt_shape=[1, 3, 640, 640],
+                    max_shape=[1, 3, 640, 640])))
+    ])
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=2000,
+        keep_top_k=50,
+        background_label_id=-1,
+    ))
diff --git a/docs/en/04-supported-codebases/mmpose.md b/docs/en/04-supported-codebases/mmpose.md
index 6f6ee4ab50..8c822cebc9 100644
--- a/docs/en/04-supported-codebases/mmpose.md
+++ b/docs/en/04-supported-codebases/mmpose.md
@@ -161,3 +161,4 @@ TODO
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
 | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
+| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |
diff --git a/docs/zh_cn/04-supported-codebases/mmpose.md b/docs/zh_cn/04-supported-codebases/mmpose.md
index 961ba31f22..617dbd670c 100644
--- a/docs/zh_cn/04-supported-codebases/mmpose.md
+++ b/docs/zh_cn/04-supported-codebases/mmpose.md
@@ -165,3 +165,4 @@ task_processor.visualize(
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
 | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
+| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |
diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py
index 048433e070..71efbf545e 100644
--- a/mmdeploy/codebase/base/task.py
+++ b/mmdeploy/codebase/base/task.py
@@ -126,6 +126,11 @@ def build_pytorch_model(self,
         if hasattr(model, 'backbone') and hasattr(model.backbone,
                                                   'switch_to_deploy'):
             model.backbone.switch_to_deploy()
+
+        if hasattr(model, 'switch_to_deploy') and callable(
+                model.switch_to_deploy):
+            model.switch_to_deploy()
+
         model = model.to(self.device)
         model.eval()
         return model
diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py
index 10bd18a0d9..45ece714ad 100644
--- a/mmdeploy/codebase/mmpose/models/heads/__init__.py
+++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import mspn_head, simcc_head, yolox_pose_head  # noqa: F401,F403
+from . import mspn_head, rtmo_head, simcc_head, yolox_pose_head
 
-__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
+__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head']
diff --git a/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py
new file mode 100644
index 0000000000..20bc748ac2
--- /dev/null
+++ b/mmdeploy/codebase/mmpose/models/heads/rtmo_head.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+from mmpose.structures.bbox import bbox_xyxy2cs
+from torch import Tensor
+
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms import multiclass_nms
+from mmdeploy.utils import Backend, get_backend
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmpose.models.heads.hybrid_heads.'
+    'rtmo_head.RTMOHead.forward')
+def predict(self,
+            x: Tuple[Tensor],
+            batch_data_samples: List = [],
+            test_cfg: Optional[dict] = None):
+    """Get predictions and transform to bbox and keypoint results.
+    Args:
+        x (Tuple[Tensor]): The input tensor from upstream network.
+        batch_data_samples: Batch image meta info. Defaults to an empty list.
+        test_cfg: The runtime config for testing process.
+
+    Returns:
+        Tuple[Tensor]: Predict bbox and keypoint results.
+        - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
+            has shape (batch_size, num_instances, 5), the last dimension 5
+            arranged as (x1, y1, x2, y2, score).
+        - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
+            Tensor, has shape (batch_size, num_instances, num_keypoints, 3),
+            the last dimension 3 arranged as (x, y, score).
+    """
+
+    # deploy context
+    ctx = FUNCTION_REWRITER.get_context()
+    backend = get_backend(ctx.cfg)
+    deploy_cfg = ctx.cfg
+
+    cfg = self.test_cfg if test_cfg is None else test_cfg
+
+    # get predictions
+    cls_scores, bbox_preds, _, kpt_vis, pose_vecs = self.head_module(x)[:5]
+    assert len(cls_scores) == len(bbox_preds)
+    num_imgs = cls_scores[0].shape[0]
+
+    # flatten and concat predictions
+    scores = self._flatten_predictions(cls_scores).sigmoid()
+    flatten_bbox_preds = self._flatten_predictions(bbox_preds)
+    flatten_pose_vecs = self._flatten_predictions(pose_vecs)
+    flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid()
+    bboxes = self.decode_bbox(flatten_bbox_preds, self.flatten_priors,
+                              self.flatten_stride)
+
+    if backend == Backend.TENSORRT:
+        # pad for batched_nms because its output index is filled with -1
+        bboxes = torch.cat(
+            [bboxes,
+             bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
+            dim=1)
+
+        scores = torch.cat(
+            [scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)
+
+    # nms parameters
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.get('nms_thr', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.get('pre_top_k', -1)
+    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
+
+    # do nms
+    _, _, nms_indices = multiclass_nms(
+        bboxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k,
+        output_index=True)
+
+    batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)
+
+    # filter predictions
+    dets = torch.cat([bboxes, scores], dim=2)
+    dets = dets[batch_inds, nms_indices, ...]
+    pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...]
+    kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...]
+    grids = self.flatten_priors[nms_indices, ...]
+
+    # decode keypoints
+    bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1)
+    keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids)
+    pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1)
+
+    return dets, pred_kpts
diff --git a/tests/regression/mmpose.yml b/tests/regression/mmpose.yml
index ad3b7b8744..41554a6622 100644
--- a/tests/regression/mmpose.yml
+++ b/tests/regression/mmpose.yml
@@ -150,3 +150,17 @@ models:
           input_img: *img_human_pose
           test_img: *img_human_pose
         deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
+
+  - name: RTMO
+    metafile: configs/body_2d_keypoint/rtmo/body7/rtmo_body7.yml
+    model_configs:
+      - configs/body_2d_keypoint/rtmo/body7/rtmo-s_8xb32-600e_body7-640x640.py
+    pipelines:
+      - convert_image:
+          input_img: *img_human_pose
+          test_img: *img_human_pose
+        deploy_config: configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
+      - convert_image:
+          input_img: *img_human_pose
+          test_img: *img_human_pose
+        deploy_config: configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
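
Reviewer note (not part of the patch): the snippet below is a minimal sketch of how the new ONNX deploy config can be exercised, using mmdeploy.apis.torch2onnx as documented in mmdeploy's get-started guide. The mmpose model config path, checkpoint file and test image are placeholders for illustration, not files added by this PR.

    # Export RTMO-S to ONNX with the config added in this patch.
    # Placeholder paths: adjust the model config, checkpoint and image
    # to your local setup.
    from mmdeploy.apis import torch2onnx

    img = 'demo/resources/human-pose.jpg'
    work_dir = 'work_dir/rtmo_onnx'
    save_file = 'end2end.onnx'
    deploy_cfg = 'configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py'
    model_cfg = ('../mmpose/configs/body_2d_keypoint/rtmo/body7/'
                 'rtmo-s_8xb32-600e_body7-640x640.py')
    model_checkpoint = 'checkpoints/rtmo-s_8xb32-600e_body7-640x640.pth'

    # Produces work_dir/rtmo_onnx/end2end.onnx with a dynamic batch axis and
    # the 'dets'/'keypoints' outputs declared in the deploy config.
    torch2onnx(img, work_dir, save_file, deploy_cfg, model_cfg,
               model_checkpoint, device='cpu')

The TensorRT config works the same way through tools/deploy.py or the corresponding API, with the fixed 1x3x640x640 profile defined in backend_config.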
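
A quick sanity check of the exported file is to run it with ONNX Runtime and inspect the output shapes, which should match the layout documented in the rewritten RTMOHead.forward: dets of shape (batch, num_dets, 5) and keypoints of shape (batch, num_dets, num_keypoints, 3). The sketch below feeds a dummy tensor purely to verify shapes; real inference should go through mmpose's test pipeline (or mmdeploy's task processor) for preprocessing, and it assumes the exported graph contains only standard ONNX ops. If mmdeploy custom ops appear in the graph, load the model through mmdeploy's backend wrapper instead.

    # Shape check of the exported model; dummy input, no real preprocessing.
    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('work_dir/rtmo_onnx/end2end.onnx',
                                providers=['CPUExecutionProvider'])
    dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)
    dets, keypoints = sess.run(['dets', 'keypoints'], {'input': dummy})

    print(dets.shape)       # (1, num_dets, 5): x1, y1, x2, y2, score
    print(keypoints.shape)  # (1, num_dets, num_keypoints, 3): x, y, score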