diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index bd63e96872..7217e68c93 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -52,6 +52,10 @@ jobs:
       run: |
         git clone -b dev --depth 1 https://github.com/open-mmlab/mmyolo.git /home/runner/work/mmyolo
         python -m pip install -v -e /home/runner/work/mmyolo
+    - name: Install mmpose
+      run: |
+        git clone --depth 1 https://github.com/open-mmlab/mmpose.git /home/runner/work/mmpose
+        python -m pip install -v -e /home/runner/work/mmpose
     - name: Build and install
       run: |
         rm -rf .eggs && python -m pip install -e .
diff --git a/configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py b/configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
new file mode 100644
index 0000000000..38f17d7d10
--- /dev/null
+++ b/configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
@@ -0,0 +1,25 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
diff --git a/configs/mmpose/pose-detection_yolox-pose_openvino_dynamic-640x640.py b/configs/mmpose/pose-detection_yolox-pose_openvino_dynamic-640x640.py
new file mode 100644
index 0000000000..099994a225
--- /dev/null
+++ b/configs/mmpose/pose-detection_yolox-pose_openvino_dynamic-640x640.py
@@ -0,0 +1,27 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/openvino.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+backend_config = dict(
+    model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
diff --git a/configs/mmpose/pose-detection_yolox-pose_tensorrt_dynamic-640x640.py b/configs/mmpose/pose-detection_yolox-pose_tensorrt_dynamic-640x640.py
new file mode 100644
index 0000000000..e50e085e0c
--- /dev/null
+++ b/configs/mmpose/pose-detection_yolox-pose_tensorrt_dynamic-640x640.py
@@ -0,0 +1,35 @@
+_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
+
+onnx_config = dict(
+    output_names=['dets', 'keypoints'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'dets': {
+            0: 'batch',
+        },
+        'keypoints': {
+            0: 'batch'
+        }
+    })
+backend_config = dict(
+    common_config=dict(max_workspace_size=1 << 30),
+    model_inputs=[
+        dict(
+            input_shapes=dict(
+                input=dict(
+                    min_shape=[1, 3, 640, 640],
+                    opt_shape=[1, 3, 640, 640],
+                    max_shape=[1, 3, 640, 640])))
+    ])
+
+codebase_config = dict(
+    post_processing=dict(
+        score_threshold=0.05,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
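A hedged usage sketch for the three deploy configs above. torch2onnx and inference_model are existing mmdeploy APIs; the mmpose project config and checkpoint paths are placeholders for whatever YOLOX-Pose model you trained, not files added by this PR:

    from mmdeploy.apis import inference_model, torch2onnx

    img = 'demo/resources/human-pose.jpg'
    deploy_cfg = 'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py'
    model_cfg = 'projects/yolox-pose/configs/yolox-pose_s.py'  # placeholder path
    checkpoint = 'yolox-pose_s.pth'  # placeholder path

    # export with the dynamic batch axes declared in the deploy config
    torch2onnx(
        img,
        work_dir='work_dir/yolox-pose',
        save_file='end2end.onnx',
        deploy_cfg=deploy_cfg,
        model_cfg=model_cfg,
        model_checkpoint=checkpoint,
        device='cpu')

    # dets: (batch, num_instances, 5); keypoints: (batch, num_instances, K, 3)
    results = inference_model(model_cfg, deploy_cfg,
                              ['work_dir/yolox-pose/end2end.onnx'], img, 'cpu')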
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
index 8a0ec7bac8..58419f8c16 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu
@@ -50,7 +50,7 @@ __launch_bounds__(nthds_per_cta) __global__
          bboxOffset) *
         5;
     if (nmsedIndex != nullptr) {
-      nmsedIndex[i] = bboxId / 5;
+      nmsedIndex[i] = bboxId / 5 - bboxOffset;
     }
     // clipped bbox xmin
     nmsedDets[i * 6] =
@@ -74,7 +74,7 @@ __launch_bounds__(nthds_per_cta) __global__
          bboxOffset) *
         4;
     if (nmsedIndex != nullptr) {
-      nmsedIndex[i] = bboxId / 4;
+      nmsedIndex[i] = bboxId / 4 - bboxOffset;
     }
     // clipped bbox xmin
     nmsedDets[i * 5] =
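Why the subtraction matters, sketched in plain Python under the semantics visible in the kernel: bboxId indexes a batch-flattened buffer of 5-float (or 4-float) boxes, so bboxId / 5 still carries the per-image offset, while the new output_index path in multiclass_nms expects an index local to the image.

    # assumed semantics, mirroring the kernel arithmetic above
    num_preds_per_image = 100
    for batch_idx in range(2):
        bbox_offset = batch_idx * num_preds_per_image
        raw_index = 7  # index of a kept box within its own image
        bbox_id = (raw_index + bbox_offset) * 5
        assert bbox_id // 5 - bbox_offset == raw_index  # old code kept the offset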
diff --git a/docs/en/04-supported-codebases/mmpose.md b/docs/en/04-supported-codebases/mmpose.md
index be0fc19a98..e990cd3310 100644
--- a/docs/en/04-supported-codebases/mmpose.md
+++ b/docs/en/04-supported-codebases/mmpose.md
@@ -160,3 +160,4 @@ TODO
 | [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
+| [YOLOX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
diff --git a/docs/zh_cn/04-supported-codebases/mmpose.md b/docs/zh_cn/04-supported-codebases/mmpose.md
index f6529da2c4..0f64fc65cb 100644
--- a/docs/zh_cn/04-supported-codebases/mmpose.md
+++ b/docs/zh_cn/04-supported-codebases/mmpose.md
@@ -164,3 +164,4 @@ task_processor.visualize(
 | [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
+| [YOLOX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |
diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py
index f3b5b6dede..6584d995fb 100644
--- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py
+++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py
@@ -120,6 +120,9 @@ class MMPose(MMCodebase):
     @classmethod
     def register_deploy_modules(cls):
         """register rewritings."""
+        import mmdeploy.codebase.mmdet.models  # noqa: F401
+        import mmdeploy.codebase.mmdet.ops  # noqa: F401
+        import mmdeploy.codebase.mmdet.structures  # noqa: F401
         import mmdeploy.codebase.mmpose.models  # noqa: F401
 
     @classmethod
@@ -202,9 +205,11 @@ def create_input(self,
             raise AssertionError('imgs must be strings or numpy arrays')
         elif isinstance(imgs, (np.ndarray, str)):
             imgs = [imgs]
+            img_path = imgs
         else:
             raise AssertionError('imgs must be strings or numpy arrays')
         if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
+            img_path = imgs
             img_data = [mmcv.imread(img) for img in imgs]
             imgs = img_data
         person_results = []
@@ -220,7 +225,7 @@ def create_input(self,
             TRANSFORMS.build(c) for c in cfg.test_dataloader.dataset.pipeline
         ]
         test_pipeline = Compose(test_pipeline)
-        if input_shape is not None:
+        if input_shape is not None and hasattr(cfg, 'codec'):
             if isinstance(cfg.codec, dict):
                 codec = cfg.codec
             elif isinstance(cfg.codec, list):
@@ -243,9 +248,14 @@ def create_input(self,
             bbox_score = np.array([bbox[4] if len(bbox) == 5 else 1
                                    ])  # shape (1,)
             data = {
-                'img': imgs[i],
-                'bbox_score': bbox_score,
-                'bbox': bbox[None],  # shape (1, 4)
+                'img': imgs[i],
+                'bbox_score': bbox_score,
+                # YOLOX-Pose predicts boxes itself, so feed an empty bbox
+                # through the test pipeline in that case
+                'bbox': [] if hasattr(cfg.model, 'bbox_head')
+                and cfg.model.bbox_head.type == 'YOLOXPoseHead' else
+                bbox[None],  # shape (1, 4)
+                'img_path': img_path[i],
             }
             data.update(meta_data)
             data = test_pipeline(data)
@@ -288,11 +298,17 @@ def visualize(self,
 
         if isinstance(image, str):
             image = mmcv.imread(image, channel_order='rgb')
+        draw_bbox = result.pred_instances.bboxes is not None
+        if draw_bbox and isinstance(result.pred_instances.bboxes,
+                                    torch.Tensor):
+            result.pred_instances.bboxes = \
+                result.pred_instances.bboxes.cpu().numpy()
         visualizer.add_datasample(
             name,
             image,
             data_sample=result,
             draw_gt=False,
+            draw_bbox=draw_bbox,
             show=show_result,
             out_file=output_file)
 
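For context, the create_input/visualize changes above are exercised by the inference pattern already documented in docs/en/04-supported-codebases/mmpose.md; a sketch adapted to the new config, with placeholder model files:

    import torch

    from mmdeploy.apis.utils import build_task_processor
    from mmdeploy.utils import get_input_shape, load_config

    deploy_cfg, model_cfg = load_config(
        'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py',
        'yolox-pose_s.py')  # placeholder mmpose project config
    task_processor = build_task_processor(model_cfg, deploy_cfg, device='cpu')
    model = task_processor.build_backend_model(['end2end.onnx'])  # placeholder

    # for YOLOXPoseHead configs, create_input now feeds an empty 'bbox' and
    # the original 'img_path' through the test pipeline
    model_inputs, _ = task_processor.create_input('human-pose.jpg',
                                                  get_input_shape(deploy_cfg))
    with torch.no_grad():
        results = model.test_step(model_inputs)

    # visualize() now also draws bboxes when pred_instances carries them
    task_processor.visualize(
        image='human-pose.jpg',
        model=model,
        result=results[0],
        window_name='yolox-pose',
        output_file='output_pose.png')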
+ """ + assert preds[0].shape[0] == len(data_samples) + batched_dets, batched_kpts = preds + for data_sample_idx, data_sample in enumerate(data_samples): + bboxes = batched_dets[data_sample_idx, :, :4] + bbox_scores = batched_dets[data_sample_idx, :, 4] + keypoints = batched_kpts[data_sample_idx, :, :, :2] + keypoint_scores = batched_kpts[data_sample_idx, :, :, 2] + + # filter zero or negative scores + inds = bbox_scores > 0.0 + bboxes = bboxes[inds, :] + bbox_scores = bbox_scores[inds] + keypoints = keypoints[inds, :] + keypoint_scores = keypoint_scores[inds] + + pred_instances = InstanceData() + # rescale + scale_factor = data_sample.scale_factor + scale_factor = keypoints.new_tensor(scale_factor) + keypoints /= keypoints.new_tensor(scale_factor).reshape(1, 1, 2) + bboxes /= keypoints.new_tensor(scale_factor).repeat(1, 2) + pred_instances.bboxes = bboxes.cpu().numpy() + pred_instances.bbox_scores = bbox_scores + # the precision test requires keypoints to be np.ndarray + pred_instances.keypoints = keypoints.cpu().numpy() + pred_instances.keypoint_scores = keypoint_scores + pred_instances.lebels = torch.zeros(bboxes.shape[0]) + + data_sample.pred_instances = pred_instances + return data_samples + @__BACKEND_MODEL.register_module('sdk') class SDKEnd2EndModel(End2EndModel): @@ -236,8 +282,13 @@ def build_pose_detection_model( if isinstance(data_preprocessor, dict): dp = data_preprocessor.copy() dp_type = dp.pop('type') - assert dp_type == 'PoseDataPreprocessor' - data_preprocessor = PoseDataPreprocessor(**dp) + if dp_type == 'mmdet.DetDataPreprocessor': + from mmdet.models.data_preprocessors import DetDataPreprocessor + data_preprocessor = DetDataPreprocessor(**dp) + else: + assert dp_type == 'PoseDataPreprocessor' + data_preprocessor = PoseDataPreprocessor(**dp) + backend_pose_model = __BACKEND_MODEL.build( dict( type=model_type, diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py index d42abf17f8..9fb6239cdb 100644 --- a/mmdeploy/codebase/mmpose/models/heads/__init__.py +++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import mspn_head +from . import mspn_head, yolox_pose_head # noqa: F401,F403 -__all__ = ['mspn_head'] +__all__ = ['mspn_head', 'yolox_pose_head'] diff --git a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py new file mode 100644 index 0000000000..1866dd215b --- /dev/null +++ b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py @@ -0,0 +1,208 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.mmcv.ops.nms import multiclass_nms +from mmdeploy.utils import Backend +from mmdeploy.utils.config_utils import get_backend_config + + +@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.' + 'YOLOXPoseHead.predict') +def predict(self, + x: Tuple[Tensor], + batch_data_samples=None, + rescale: bool = True): + """Get predictions and transform to bbox and keypoints results. + Args: + x (Tuple[Tensor]): The input tensor from upstream network. + batch_data_samples: Batch image meta info. Defaults to None. + rescale: If True, return boxes in original image space. + Defaults to False. 
diff --git a/mmdeploy/codebase/mmpose/models/heads/__init__.py b/mmdeploy/codebase/mmpose/models/heads/__init__.py
index d42abf17f8..9fb6239cdb 100644
--- a/mmdeploy/codebase/mmpose/models/heads/__init__.py
+++ b/mmdeploy/codebase/mmpose/models/heads/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import mspn_head
+from . import mspn_head, yolox_pose_head  # noqa: F401,F403
 
-__all__ = ['mspn_head']
+__all__ = ['mspn_head', 'yolox_pose_head']
diff --git a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
new file mode 100644
index 0000000000..1866dd215b
--- /dev/null
+++ b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Tuple
+
+import torch
+from mmengine.config import ConfigDict
+from torch import Tensor
+
+from mmdeploy.codebase.mmdet import get_post_processing_params
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.mmcv.ops.nms import multiclass_nms
+from mmdeploy.utils import Backend
+from mmdeploy.utils.config_utils import get_backend_config
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
+                                     'YOLOXPoseHead.predict')
+def yolox_pose_head__predict(self,
+                             x: Tuple[Tensor],
+                             batch_data_samples=None,
+                             rescale: bool = True):
+    """Get predictions and transform to bbox and keypoint results.
+
+    Args:
+        x (Tuple[Tensor]): The input tensor from upstream network.
+        batch_data_samples: Batch image meta info. Defaults to None.
+        rescale: If True, return boxes in original image space.
+            Defaults to True.
+
+    Returns:
+        Tuple[Tensor]: Predict bbox and keypoint results.
+        - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
+          has shape (batch_size, num_instances, 5), the last dimension 5
+          arranged as (x1, y1, x2, y2, score).
+        - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
+          Tensor, has shape (batch_size, num_instances, num_keypoints, 3),
+          the last dimension 3 arranged as (x, y, score).
+    """
+    outs = self(x)
+    predictions = self.predict_by_feat(
+        *outs, batch_img_metas=batch_data_samples, rescale=rescale)
+    return predictions
+
+
+@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
+                                     'YOLOXPoseHead.predict_by_feat')
+def yolox_pose_head__predict_by_feat(
+        self,
+        cls_scores: List[Tensor],
+        bbox_preds: List[Tensor],
+        objectnesses: Optional[List[Tensor]] = None,
+        kpt_preds: Optional[List[Tensor]] = None,
+        vis_preds: Optional[List[Tensor]] = None,
+        batch_img_metas: Optional[List[dict]] = None,
+        cfg: Optional[ConfigDict] = None,
+        rescale: bool = True,
+        with_nms: bool = True) -> Tuple[Tensor]:
+    """Transform a batch of output features extracted by the head into
+    bbox and keypoint results.
+
+    In addition to the base class method, keypoint predictions are also
+    calculated in this method.
+
+    Args:
+        cls_scores (List[Tensor]): Classification scores for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * num_classes, H, W).
+        bbox_preds (List[Tensor]): Box energies / deltas for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_priors * 4, H, W).
+        objectnesses (Optional[List[Tensor]]): Score factor for
+            all scale levels, each is a 4D-tensor, has shape
+            (batch_size, 1, H, W).
+        kpt_preds (Optional[List[Tensor]]): Keypoints for all
+            scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_keypoints * 2, H, W).
+        vis_preds (Optional[List[Tensor]]): Keypoint scores for
+            all scale levels, each is a 4D-tensor, has shape
+            (batch_size, num_keypoints, H, W).
+        batch_img_metas (Optional[List[dict]]): Batch image meta
+            info. Defaults to None.
+        cfg (Optional[ConfigDict]): Test / postprocessing
+            configuration, if None, test_cfg would be used.
+            Defaults to None.
+        rescale (bool): If True, return boxes in original image space.
+            Defaults to True.
+        with_nms (bool): If True, do nms before returning boxes.
+            Defaults to True.
+
+    Returns:
+        Tuple[Tensor]: Predict bbox and keypoint results.
+        - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
+          has shape (batch_size, num_instances, 5), the last dimension 5
+          arranged as (x1, y1, x2, y2, score).
+        - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
+          Tensor, has shape (batch_size, num_instances, num_keypoints, 3),
+          the last dimension 3 arranged as (x, y, score).
+    """
+    ctx = FUNCTION_REWRITER.get_context()
+    deploy_cfg = ctx.cfg
+    dtype = cls_scores[0].dtype
+    device = cls_scores[0].device
+    bbox_decoder = self.bbox_coder.decode
+
+    assert len(cls_scores) == len(bbox_preds)
+    cfg = self.test_cfg if cfg is None else cfg
+
+    num_imgs = cls_scores[0].shape[0]
+    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
+
+    self.mlvl_priors = self.prior_generator.grid_priors(
+        featmap_sizes, dtype=dtype, device=device)
+
+    flatten_priors = torch.cat(self.mlvl_priors)
+
+    mlvl_strides = [
+        flatten_priors.new_full(
+            (featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
+            stride)
+        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
+    ]
+    flatten_stride = torch.cat(mlvl_strides)
+
+    # flatten cls_scores, bbox_preds and objectness
+    flatten_cls_scores = [
+        cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
+        for cls_score in cls_scores
+    ]
+    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
+
+    flatten_bbox_preds = [
+        bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+        for bbox_pred in bbox_preds
+    ]
+    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
+
+    if objectnesses is not None:
+        flatten_objectness = [
+            objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
+            for objectness in objectnesses
+        ]
+        flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
+        cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1))
+
+    scores = cls_scores
+    bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds,
+                          flatten_stride)
+
+    # deal with keypoints
+    priors = torch.cat(self.mlvl_priors)
+    strides = [
+        priors.new_full((featmap_size.numel() * self.num_base_priors, ),
+                        stride)
+        for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
+    ]
+    strides = torch.cat(strides)
+    kpt_preds = torch.cat([
+        kpt_pred.permute(0, 2, 3, 1).reshape(
+            num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds
+    ],
+                          dim=1)
+    flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides)
+
+    vis_preds = torch.cat([
+        vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1,
+                                             self.num_keypoints, 1)
+        for vis_pred in vis_preds
+    ],
+                          dim=1).sigmoid()
+
+    pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3)
+
+    backend_config = get_backend_config(deploy_cfg)
+    if backend_config.type == Backend.TENSORRT.value:
+        # pad a dummy zero entry so that the -1 indices produced by the
+        # TensorRT NMS plugin for empty slots gather the padded entry
+        bboxes = torch.cat(
+            [bboxes,
+             bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
+            dim=1)
+        scores = torch.cat(
+            [scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)
+        pred_kpts = torch.cat([
+            pred_kpts,
+            pred_kpts.new_zeros((pred_kpts.shape[0], 1, pred_kpts.shape[2],
+                                 pred_kpts.shape[3]))
+        ],
+                              dim=1)
+
+    # nms
+    post_params = get_post_processing_params(deploy_cfg)
+    max_output_boxes_per_class = post_params.max_output_boxes_per_class
+    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
+    score_threshold = cfg.get('score_thr', post_params.score_threshold)
+    pre_top_k = post_params.get('pre_top_k', -1)
+    keep_top_k = post_params.get('keep_top_k', -1)
+    # do nms
+    _, _, nms_indices = multiclass_nms(
+        bboxes,
+        scores,
+        max_output_boxes_per_class,
+        iou_threshold,
+        score_threshold,
+        pre_top_k=pre_top_k,
+        keep_top_k=keep_top_k,
+        output_index=True)
+
+    batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)
+    dets = torch.cat([bboxes, scores], dim=2)
+    dets = dets[batch_inds, nms_indices, ...]
+    pred_kpts = pred_kpts[batch_inds, nms_indices, ...]
+
+    return dets, pred_kpts
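The batched gather at the end of predict_by_feat, shown standalone with dummy shapes (a sketch of the indexing pattern, not mmdeploy API):

    import torch

    num_imgs, num_boxes, num_kpts, keep_top_k = 2, 8, 17, 5
    bboxes = torch.rand(num_imgs, num_boxes, 4)
    scores = torch.rand(num_imgs, num_boxes, 1)
    pred_kpts = torch.rand(num_imgs, num_boxes, num_kpts, 3)
    # what multiclass_nms(..., output_index=True) returns: per-image indices
    # into the pre-NMS boxes, padded with 0 for empty slots
    nms_indices = torch.randint(0, num_boxes, (num_imgs, keep_top_k))

    batch_inds = torch.arange(num_imgs).view(-1, 1)
    dets = torch.cat([bboxes, scores], dim=2)[batch_inds, nms_indices, ...]
    kpts = pred_kpts[batch_inds, nms_indices, ...]
    assert dets.shape == (num_imgs, keep_top_k, 5)
    assert kpts.shape == (num_imgs, keep_top_k, num_kpts, 3)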
diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py
index 671fde41dc..88a6fe4501 100644
--- a/mmdeploy/mmcv/ops/nms.py
+++ b/mmdeploy/mmcv/ops/nms.py
@@ -186,7 +186,9 @@ def _select_nms_index(scores: torch.Tensor,
                       boxes: torch.Tensor,
                       nms_index: torch.Tensor,
                       batch_size: int,
-                      keep_top_k: int = -1):
+                      keep_top_k: int = -1,
+                      pre_inds: torch.Tensor = None,
+                      output_index: bool = False):
     """Transform NMS output.
 
     Args:
@@ -197,6 +199,10 @@ def _select_nms_index(scores: torch.Tensor,
         batch_size (int): Batch size of the input image.
         keep_top_k (int): Number of top K boxes to keep after nms.
             Defaults to -1.
+        pre_inds (Tensor): The pre-topk indices of boxes before nms.
+            Defaults to None.
+        output_index (bool): Whether to return indices of original bboxes.
+            Defaults to False.
 
     Returns:
         tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
@@ -230,7 +236,13 @@ def _select_nms_index(scores: torch.Tensor,
                                1)
     batched_labels = torch.cat((batched_labels, batched_labels.new_zeros(
         (N, 1))), 1)
-
+    if output_index and pre_inds is not None:
+        # broadcast pre_inds into the batched layout used above
+        pre_inds = pre_inds[batch_inds, box_inds]
+        pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1)
+        pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)),
+                                  pre_inds.new_zeros(1))
+        pre_inds = torch.cat((pre_inds, pre_inds.new_zeros((N, 1))), 1)
     # sort
     is_use_topk = keep_top_k > 0 and \
         (torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
@@ -243,7 +255,12 @@ def _select_nms_index(scores: torch.Tensor,
                               device=topk_inds.device).view(-1, 1)
     batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
     batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
-
+    if output_index:
+        if pre_inds is not None:
+            topk_inds = pre_inds[topk_batch_inds, topk_inds, ...]
+        # drop the padded placeholder column appended above
+        topk_inds = topk_inds[:, :-1]
+        return batched_dets, batched_labels, topk_inds
     # slice and recover the tensor
     return batched_dets, batched_labels
 
@@ -263,7 +280,6 @@ def _multiclass_nms(boxes: Tensor,
         shape (N, num_bboxes, num_classes) and the boxes is of shape
         (N, num_boxes, 4).
     """
-    assert not output_index, 'output_index is not supported on this backend.'
     if version.parse(torch.__version__) < version.parse('1.13.0'):
         max_output_boxes_per_class = torch.LongTensor(
             [max_output_boxes_per_class])
@@ -274,7 +290,11 @@ def _multiclass_nms(boxes: Tensor,
+    # initialize so pre-topk indices can be forwarded to _select_nms_index
+    # even when pre_top_k filtering is skipped
+    topk_inds = None
     if pre_top_k > 0:
         max_scores, _ = scores.max(-1)
         _, topk_inds = max_scores.topk(pre_top_k)
-        batch_inds = torch.arange(batch_size).view(-1, 1).long()
+        batch_inds = torch.arange(
+            batch_size, device=scores.device).view(-1, 1).long()
         boxes = boxes[batch_inds, topk_inds, :]
         scores = scores[batch_inds, topk_inds, :]
 
@@ -283,10 +303,14 @@ def _multiclass_nms(boxes: Tensor,
                                                max_output_boxes_per_class,
                                                iou_threshold, score_threshold)
 
-    dets, labels = _select_nms_index(
-        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)
-
-    return dets, labels
+    return _select_nms_index(
+        scores,
+        boxes,
+        selected_indices,
+        batch_size,
+        keep_top_k=keep_top_k,
+        pre_inds=topk_inds,
+        output_index=output_index)
 
 
 def _multiclass_nms_single(boxes: Tensor,
diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py
index 70ab9397bf..7971af1af3 100644
--- a/mmdeploy/utils/constants.py
+++ b/mmdeploy/utils/constants.py
@@ -43,6 +43,7 @@ class Codebase(AdvancedEnum):
     MMROTATE = 'mmrotate'
     MMACTION = 'mmaction'
     MMRAZOR = 'mmrazor'
+    MMYOLO = 'mmyolo'
 
 
 class IR(AdvancedEnum):
diff --git a/tests/test_codebase/test_mmpose/test_mmpose_models.py b/tests/test_codebase/test_mmpose/test_mmpose_models.py
index 3864c1fc76..93721e129f 100644
--- a/tests/test_codebase/test_mmpose/test_mmpose_models.py
+++ b/tests/test_codebase/test_mmpose/test_mmpose_models.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import mmengine
 import pytest
 import torch
 
@@ -188,3 +189,43 @@ def test_scale_forward(backend_type: Backend):
         deploy_cfg=deploy_cfg,
         run_with_backend=False)
     torch_assert_close(rewrite_outputs, model_outputs)
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_yolox_pose_head(backend_type: Backend):
+    try:
+        from models import yolox_pose_head  # noqa: F401,F403
+    except ImportError:
+        pytest.skip('mmpose/projects/yolox-pose is not installed.')
+    check_backend(backend_type, True)
+    deploy_cfg = mmengine.Config.fromfile(
+        'configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py')
+    model = yolox_pose_head.YOLOXPoseHead(
+        head_module=dict(
+            type='YOLOXPoseHeadModule',
+            num_classes=1,
+            in_channels=256,
+            feat_channels=256,
+            widen_factor=0.5,
+            stacked_convs=2,
+            num_keypoints=17,
+            featmap_strides=(8, 16, 32),
+            use_depthwise=False,
+            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+            act_cfg=dict(type='SiLU', inplace=True),
+        ))
+    model.cpu().eval()
+    model_inputs = [
+        torch.randn(2, 128, 80, 80),
+        torch.randn(2, 128, 40, 40),
+        torch.randn(2, 128, 20, 20)
+    ]
+    pytorch_output = model(model_inputs)
+    wrapped_model = WrapModel(model, 'forward')
+    rewrite_inputs = {'inputs': model_inputs}
+    rewrite_outputs, _ = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        run_with_backend=False,
+        deploy_cfg=deploy_cfg)
+    torch_assert_close(rewrite_outputs, pytorch_output)
diff --git a/tests/test_codebase/test_mmpose/utils.py b/tests/test_codebase/test_mmpose/utils.py
index 0d6eeecdd3..fe41bf3aed 100644
--- a/tests/test_codebase/test_mmpose/utils.py
+++ b/tests/test_codebase/test_mmpose/utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import mmengine
+import numpy
 import torch
 from mmengine.structures import InstanceData, PixelData
 
@@ -16,7 +17,7 @@ def generate_datasample(img_size, heatmap_size=(64, 48)):
         input_size=(h, w),
         heatmap_size=heatmap_size)
     pred_instances = InstanceData()
-    pred_instances.bboxes = torch.rand((1, 4)).numpy()
+    pred_instances.bboxes = numpy.array([[0.0, 0.0, 1.0, 1.0]])
     pred_instances.bbox_scales = torch.ones(1, 2).numpy()
     pred_instances.bbox_scores = torch.ones(1).numpy()
    pred_instances.bbox_centers = torch.ones(1, 2).numpy()
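Finally, the pre_inds bookkeeping added to _select_nms_index, reduced to its core (a sketch under assumed semantics: when pre_top_k filtering ran before NMS, the NMS indices refer to the filtered set, and pre_inds maps them back to original box ids):

    import torch

    scores = torch.rand(1, 100)
    _, pre_inds = scores.topk(20)           # original ids of the pre-topk boxes
    nms_keep = torch.tensor([[3, 11, 17]])  # NMS picks within the filtered set
    original_ids = pre_inds.gather(1, nms_keep)
    assert original_ids.shape == (1, 3)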