From 6edd8023def5e962a14218584989bae2690dd630 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Mon, 23 Oct 2023 09:49:25 +0800 Subject: [PATCH] [Fix] fix the onnx exportation for yoloxpose in mmpose (#2466) * fix the onnx exportation for yoloxpose * remove deprecated func * refine code * fix the rescaling process of top-down models * fix ut * add yoloxpose in regression test * fix comment * rebase & fix conflict --- .../codebase/mmpose/deploy/pose_detection.py | 2 +- .../mmpose/deploy/pose_detection_model.py | 36 +++-- .../mmpose/models/heads/yolox_pose_head.py | 145 ++++-------------- tests/regression/mmpose.yml | 10 ++ tests/test_codebase/test_mmpose/utils.py | 2 + 5 files changed, 64 insertions(+), 131 deletions(-) diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 86f2e4d09a..5e6b0c5c6f 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -60,7 +60,7 @@ def process_model_config( type='Normalize', mean=data_preprocessor.mean, std=data_preprocessor.std, - to_rgb=data_preprocessor.bgr_to_rgb)) + to_rgb=data_preprocessor.get('bgr_to_rgb', False))) test_pipeline.append(dict(type='ImageToTensor', keys=['img'])) test_pipeline.append( dict( diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py index be10a80e32..a2ec9a21ad 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection_model.py @@ -98,14 +98,15 @@ def forward(self, inputs = inputs.contiguous().to(self.device) batch_outputs = self.wrapper({self.input_name: inputs}) batch_outputs = self.wrapper.output_to_list(batch_outputs) - if self.model_cfg.model.type == 'YOLODetector': - return self.pack_yolox_pose_result(batch_outputs, data_samples) codebase_cfg = get_codebase_config(self.deploy_cfg) codec = self.model_cfg.codec if isinstance(codec, (list, tuple)): codec = codec[-1] - if codec.type == 'SimCCLabel': + + if codec.type == 'YOLOXPoseAnnotationProcessor': + return self.pack_yolox_pose_result(batch_outputs, data_samples) + elif codec.type == 'SimCCLabel': export_postprocess = codebase_cfg.get('export_postprocess', False) if export_postprocess: keypoints, scores = [_.cpu().numpy() for _ in batch_outputs] @@ -134,7 +135,7 @@ def pack_result(self, convert_coordinate (bool): Whether to convert keypoints coordinates to original image space. Default is True. Returns: - data_samples (List[BaseDataElement]): + data_samples (List[BaseDataElement]): updated data_samples with predictions. """ if isinstance(preds, tuple): @@ -153,11 +154,11 @@ def pack_result(self, # convert keypoint coordinates from input space to image space if convert_coordinate: input_size = data_sample.metainfo['input_size'] - bbox_centers = gt_instances.bbox_centers - bbox_scales = gt_instances.bbox_scales + input_center = data_sample.metainfo['input_center'] + input_scale = data_sample.metainfo['input_scale'] keypoints = pred_instances.keypoints - keypoints = keypoints / input_size * bbox_scales - keypoints += bbox_centers - 0.5 * bbox_scales + keypoints = keypoints / input_size * input_scale + keypoints += input_center - 0.5 * input_scale pred_instances.keypoints = keypoints pred_instances.bboxes = gt_instances.bboxes @@ -178,7 +179,7 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor], data_samples (List[BaseDataElement]): A list of meta info for image(s). 
Returns: - data_samples (List[BaseDataElement]): + data_samples (List[BaseDataElement]): updated data_samples with predictions. """ assert preds[0].shape[0] == len(data_samples) @@ -197,11 +198,20 @@ def pack_yolox_pose_result(self, preds: List[torch.Tensor], keypoint_scores = keypoint_scores[inds] pred_instances = InstanceData() + # rescale - scale_factor = data_sample.scale_factor - scale_factor = keypoints.new_tensor(scale_factor) - keypoints /= keypoints.new_tensor(scale_factor).reshape(1, 1, 2) - bboxes /= keypoints.new_tensor(scale_factor).repeat(1, 2) + input_size = data_sample.metainfo['input_size'] + input_center = data_sample.metainfo['input_center'] + input_scale = data_sample.metainfo['input_scale'] + + rescale = keypoints.new_tensor(input_scale) / keypoints.new_tensor( + input_size) + translation = keypoints.new_tensor( + input_center) - 0.5 * keypoints.new_tensor(input_scale) + + keypoints = keypoints * rescale.reshape( + 1, 1, 2) + translation.reshape(1, 1, 2) + bboxes = bboxes * rescale.repeat(1, 2) + translation.repeat(1, 2) pred_instances.bboxes = bboxes.cpu().numpy() pred_instances.bbox_scores = bbox_scores # the precision test requires keypoints to be np.ndarray diff --git a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py index 1f061cc945..7553a9a6a0 100644 --- a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py @@ -2,7 +2,6 @@ from typing import List, Optional, Tuple import torch -from mmengine.config import ConfigDict from torch import Tensor from mmdeploy.codebase.mmdet import get_post_processing_params @@ -11,18 +10,18 @@ from mmdeploy.utils import Backend, get_backend -@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.' - 'YOLOXPoseHead.predict') +@FUNCTION_REWRITER.register_rewriter( + func_name='mmpose.models.heads.hybrid_heads.' + 'yoloxpose_head.YOLOXPoseHead.forward') def predict(self, x: Tuple[Tensor], - batch_data_samples=None, - rescale: bool = True): + batch_data_samples: List = [], + test_cfg: Optional[dict] = None): """Get predictions and transform to bbox and keypoints results. Args: x (Tuple[Tensor]): The input tensor from upstream network. batch_data_samples: Batch image meta info. Defaults to None. - rescale: If True, return boxes in original image space. - Defaults to False. + test_cfg: The runtime config for testing process. Returns: Tuple[Tensor]: Predict bbox and keypoint results. @@ -33,73 +32,17 @@ def predict(self, Tensor, has shape (batch_size, num_instances, num_keypoints, 5), the last dimension 3 arrange as (x, y, score). """ - outs = self(x) - predictions = self.predict_by_feat( - *outs, batch_img_metas=batch_data_samples, rescale=rescale) - return predictions - - -@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.' - 'YOLOXPoseHead.predict_by_feat') -def yolox_pose_head__predict_by_feat( - self, - cls_scores: List[Tensor], - bbox_preds: List[Tensor], - objectnesses: Optional[List[Tensor]] = None, - kpt_preds: Optional[List[Tensor]] = None, - vis_preds: Optional[List[Tensor]] = None, - batch_img_metas: Optional[List[dict]] = None, - cfg: Optional[ConfigDict] = None, - rescale: bool = True, - with_nms: bool = True) -> Tuple[Tensor]: - """Transform a batch of output features extracted by the head into bbox and - keypoint results. - - In addition to the base class method, keypoint predictions are also - calculated in this method. 
- Args: - cls_scores (List[Tensor]): Classification scores for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * num_classes, H, W). - bbox_preds (List[Tensor]): Box energies / deltas for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_priors * 4, H, W). - objectnesses (Optional[List[Tensor]]): Score factor for - all scale level, each is a 4D-tensor, has shape - (batch_size, 1, H, W). - kpt_preds (Optional[List[Tensor]]): Keypoints for all - scale levels, each is a 4D-tensor, has shape - (batch_size, num_keypoints * 2, H, W) - vis_preds (Optional[List[Tensor]]): Keypoints scores for - all scale levels, each is a 4D-tensor, has shape - (batch_size, num_keypoints, H, W) - batch_img_metas (Optional[List[dict]]): Batch image meta - info. Defaults to None. - cfg (Optional[ConfigDict]): Test / postprocessing - configuration, if None, test_cfg would be used. - Defaults to None. - rescale (bool): If True, return boxes in original image space. - Defaults to False. - with_nms (bool): If True, do nms before return boxes. - Defaults to True. - Returns: - Tuple[Tensor]: Predict bbox and keypoint results. - - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor, - has shape (batch_size, num_instances, 5), the last dimension 5 - arrange as (x1, y1, x2, y2, score). - - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D - Tensor, has shape (batch_size, num_instances, num_keypoints, 5), - the last dimension 3 arrange as (x, y, score). - """ + cls_scores, objectnesses, bbox_preds, kpt_offsets, \ + kpt_vis = self.head_module(x)[:5] + ctx = FUNCTION_REWRITER.get_context() deploy_cfg = ctx.cfg dtype = cls_scores[0].dtype device = cls_scores[0].device - bbox_decoder = self.bbox_coder.decode assert len(cls_scores) == len(bbox_preds) - cfg = self.test_cfg if cfg is None else cfg + cfg = self.test_cfg if test_cfg is None else test_cfg num_imgs = cls_scores[0].shape[0] featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] @@ -110,60 +53,27 @@ def yolox_pose_head__predict_by_feat( flatten_priors = torch.cat(self.mlvl_priors) mlvl_strides = [ - flatten_priors.new_full( - (featmap_size[0] * featmap_size[1] * self.num_base_priors, ), - stride) + flatten_priors.new_full((featmap_size.numel(), ), stride) for featmap_size, stride in zip(featmap_sizes, self.featmap_strides) ] flatten_stride = torch.cat(mlvl_strides) # flatten cls_scores, bbox_preds and objectness - flatten_cls_scores = [ - cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes) - for cls_score in cls_scores - ] - cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() - - flatten_bbox_preds = [ - bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) - for bbox_pred in bbox_preds - ] - flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) - - if objectnesses is not None: - flatten_objectness = [ - objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) - for objectness in objectnesses - ] - flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid() - cls_scores = cls_scores * (flatten_objectness.unsqueeze(-1)) - - scores = cls_scores - bboxes = bbox_decoder(flatten_priors[None], flatten_bbox_preds, - flatten_stride) - - # deal with key-poinsts - priors = torch.cat(self.mlvl_priors) - strides = [ - priors.new_full((featmap_size.numel() * self.num_base_priors, ), - stride) - for featmap_size, stride in zip(featmap_sizes, self.featmap_strides) - ] - strides = torch.cat(strides) - kpt_preds = torch.cat([ - kpt_pred.permute(0, 2, 3, 1).reshape( - 
num_imgs, -1, self.num_keypoints * 2) for kpt_pred in kpt_preds - ], - dim=1) - flatten_decoded_kpts = self.decode_pose(priors, kpt_preds, strides) - - vis_preds = torch.cat([ - vis_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_keypoints, - 1) for vis_pred in vis_preds - ], - dim=1).sigmoid() - - pred_kpts = torch.cat([flatten_decoded_kpts, vis_preds], dim=3) + flatten_cls_scores = self._flatten_predictions(cls_scores).sigmoid() + flatten_bbox_preds = self._flatten_predictions(bbox_preds) + flatten_objectness = self._flatten_predictions(objectnesses).sigmoid() + flatten_kpt_offsets = self._flatten_predictions(kpt_offsets) + flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid() + bboxes = self.decode_bbox(flatten_bbox_preds, flatten_priors, + flatten_stride) + flatten_decoded_kpts = self.decode_kpt_reg(flatten_kpt_offsets, + flatten_priors, flatten_stride) + + scores = flatten_cls_scores * flatten_objectness + + pred_kpts = torch.cat([flatten_decoded_kpts, + flatten_kpt_vis.unsqueeze(3)], + dim=3) backend = get_backend(deploy_cfg) if backend == Backend.TENSORRT: @@ -184,10 +94,11 @@ def yolox_pose_head__predict_by_feat( # nms post_params = get_post_processing_params(deploy_cfg) max_output_boxes_per_class = post_params.max_output_boxes_per_class - iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + iou_threshold = cfg.get('nms_thr', post_params.iou_threshold) score_threshold = cfg.get('score_thr', post_params.score_threshold) pre_top_k = post_params.get('pre_top_k', -1) keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + # do nms _, _, nms_indices = multiclass_nms( bboxes, diff --git a/tests/regression/mmpose.yml b/tests/regression/mmpose.yml index 44704c95e0..ad3b7b8744 100644 --- a/tests/regression/mmpose.yml +++ b/tests/regression/mmpose.yml @@ -140,3 +140,13 @@ models: sdk_config: configs/mmpose/pose-detection_simcc_sdk_static-256x192.py - convert_image: *convert_image deploy_config: configs/mmpose/pose-detection_simcc_ncnn_static-256x192.py + + - name: YOLOX-Pose + metafile: configs/body_2d_keypoint/yoloxpose/coco/yoloxpose_coco.yml + model_configs: + - configs/body_2d_keypoint/yoloxpose/coco/yoloxpose_s_8xb32-300e_coco-640.py + pipelines: + - convert_image: + input_img: *img_human_pose + test_img: *img_human_pose + deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py diff --git a/tests/test_codebase/test_mmpose/utils.py b/tests/test_codebase/test_mmpose/utils.py index fe41bf3aed..15ae030c70 100644 --- a/tests/test_codebase/test_mmpose/utils.py +++ b/tests/test_codebase/test_mmpose/utils.py @@ -15,6 +15,8 @@ def generate_datasample(img_size, heatmap_size=(64, 48)): img_shape=(h, w, 3), crop_size=(h, w), input_size=(h, w), + input_center=numpy.asarray((h / 2, w / 2)), + input_scale=numpy.asarray((h, w)), heatmap_size=heatmap_size) pred_instances = InstanceData() pred_instances.bboxes = numpy.array([[0.0, 0.0, 1.0, 1.0]])
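
Note on the rescaling math: the new pack_result/pack_yolox_pose_result code maps predictions from model-input space back to the original image using the input_center and input_scale recorded in each data sample's metainfo, i.e. the affine transform p_img = p_in * (input_scale / input_size) + (input_center - 0.5 * input_scale). A minimal standalone sketch of that transform (the helper name is illustrative, not part of the patch):

    import torch

    def input_to_image_space(points: torch.Tensor, input_size, input_center,
                             input_scale) -> torch.Tensor:
        """Map (x, y) points from model-input pixels back to original-image
        pixels, mirroring the rescaling in pack_yolox_pose_result."""
        # Per-axis scale factor between the preprocessing crop and the input.
        rescale = points.new_tensor(input_scale) / points.new_tensor(input_size)
        # Shift from the input origin to the crop's top-left corner.
        translation = (points.new_tensor(input_center)
                       - 0.5 * points.new_tensor(input_scale))
        return points * rescale + translation

For example, a keypoint at (320, 320) in a 640x640 input whose crop was centered at (500, 400) with scale (800, 800) maps to (500, 400). When input_scale equals input_size and input_center is the image midpoint (the identity crop that generate_datasample now provides), the transform reduces to the identity.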
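
The rewritten head decodes the flattened per-prior outputs with decode_bbox and decode_kpt_reg. Both follow the usual YOLOX parameterization: offsets are scaled by the prior's stride and added to the prior's grid position, and box sizes pass through an exponential. A hedged sketch of that decoding, assuming the standard YOLOX formulation (the actual mmpose implementations may differ in detail):

    import torch

    def decode_yolox_boxes(bbox_preds, priors, strides):
        # bbox_preds: (B, N, 4) raw predictions; priors: (N, 2) grid centers;
        # strides: (N,) per-prior strides, as built from mlvl_strides above.
        s = strides[None, :, None]                   # (1, N, 1)
        xys = bbox_preds[..., :2] * s + priors       # box centers in input space
        whs = bbox_preds[..., 2:4].exp() * s         # box width / height
        return torch.cat([xys - whs / 2, xys + whs / 2], dim=-1)  # xyxy

    def decode_yolox_kpts(kpt_offsets, priors, strides, num_keypoints):
        # kpt_offsets: (B, N, K*2) -> (B, N, K, 2) keypoints in input space.
        b, n = kpt_offsets.shape[:2]
        offsets = kpt_offsets.reshape(b, n, num_keypoints, 2)
        return offsets * strides[None, :, None, None] + priors[None, :, None, :]

The decoded keypoints are then concatenated with the sigmoided visibility scores (flatten_kpt_vis.unsqueeze(3)) to form the (batch_size, num_instances, num_keypoints, 3) pred_kpts tensor, whose last dimension is arranged as (x, y, score).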
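
On the NMS configuration: mmpose's YOLOXPoseHead carries a flat test_cfg (score_thr, nms_thr, max_per_img) rather than the nested mmdet-style cfg.nms.iou_threshold, which is why the threshold lookup changes to cfg.get('nms_thr', ...). The deploy config's post-processing params act as fallbacks. A small sketch of that resolution order (the helper is illustrative; key names mirror the patch):

    def resolve_nms_params(cfg, post_params):
        # cfg is the head's flat test_cfg; post_params comes from
        # get_post_processing_params(deploy_cfg) and is used only when a
        # key is absent from the model config.
        return dict(
            max_output_boxes_per_class=post_params.max_output_boxes_per_class,
            iou_threshold=cfg.get('nms_thr', post_params.iou_threshold),
            score_threshold=cfg.get('score_thr', post_params.score_threshold),
            pre_top_k=post_params.get('pre_top_k', -1),
            keep_top_k=cfg.get('max_per_img', post_params.keep_top_k),
        )

These values feed multiclass_nms, which (as the `_, _, nms_indices = multiclass_nms(...)` call shows) returns the indices of the kept boxes so pred_kpts can be gathered with the same selection.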