From a52b49f52cfea9e1a6f5f46f6129bfd891f8aae8 Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 27 Sep 2023 19:12:37 +0800 Subject: [PATCH] refine code --- .../codebase/mmpose/models/heads/yolox_pose_head.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py index ea31ccc4d1..7553a9a6a0 100644 --- a/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py +++ b/mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Tuple +from typing import List, Optional, Tuple import torch from torch import Tensor @@ -16,13 +16,12 @@ def predict(self, x: Tuple[Tensor], batch_data_samples: List = [], - test_cfg: dict = dict()): + test_cfg: Optional[dict] = None): """Get predictions and transform to bbox and keypoints results. Args: x (Tuple[Tensor]): The input tensor from upstream network. batch_data_samples: Batch image meta info. Defaults to None. - rescale: If True, return boxes in original image space. - Defaults to False. + test_cfg: The runtime config for testing process. Returns: Tuple[Tensor]: Predict bbox and keypoint results. @@ -43,7 +42,7 @@ def predict(self, device = cls_scores[0].device assert len(cls_scores) == len(bbox_preds) - cfg = self.test_cfg + cfg = self.test_cfg if test_cfg is None else test_cfg num_imgs = cls_scores[0].shape[0] featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] @@ -113,7 +112,7 @@ def predict(self, batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1) dets = torch.cat([bboxes, scores], dim=2) - dets = dets[batch_inds, nms_indices, ...] # [1, n, 5] - pred_kpts = pred_kpts[batch_inds, nms_indices, ...] # [1, n, 17, 3] + dets = dets[batch_inds, nms_indices, ...] + pred_kpts = pred_kpts[batch_inds, nms_indices, ...] return dets, pred_kpts