Skip to content

Commit

Permalink
use mmcv.ops.nms
Browse files Browse the repository at this point in the history
  • Loading branch information
huayuan4396 committed Jun 19, 2023
1 parent 3266f9f commit 09443b8
Showing 1 changed file with 4 additions and 62 deletions.
66 changes: 4 additions & 62 deletions mmdeploy/codebase/mmpose/models/heads/yolox_pose_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,65 +9,7 @@

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER


def yolox_pose_head_nms(boxes: Tensor,
scores: Tensor,
max_output_boxes_per_class: int = 1000,
iou_threshold: float = 0.5,
score_threshold: float = 0.05,
pre_top_k: int = -1,
keep_top_k: int = -1,
output_index: bool = True):
from packaging import version

from mmdeploy.mmcv.ops.nms import ONNXNMSop

if version.parse(torch.__version__) < version.parse('1.13.0'):
max_output_boxes_per_class = torch.LongTensor(
[max_output_boxes_per_class])
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)

# pre topk
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.squeeze(0).topk(pre_top_k)
boxes = boxes[:, topk_inds, :]
scores = scores[:, topk_inds, :]

scores = scores.permute(0, 2, 1)
selected_indices = ONNXNMSop.apply(boxes, scores,
max_output_boxes_per_class,
iou_threshold, score_threshold)

cls_inds = selected_indices[:, 1]
box_inds = selected_indices[:, 2]

scores = scores[:, cls_inds, box_inds].unsqueeze(2)
boxes = boxes[:, box_inds, ...]
dets = torch.cat([boxes, scores], dim=2)
labels = cls_inds.unsqueeze(0)

# pad
dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

# topk or sort
is_use_topk = keep_top_k > 0 and (torch.onnx.is_in_onnx_export()
or keep_top_k < dets.shape[1])
if is_use_topk:
_, topk_inds = dets[:, :, -1].topk(keep_top_k, dim=1)
else:
_, topk_inds = dets[:, :, -1].sort(dim=1, descending=True)
topk_inds = topk_inds.squeeze(0)
dets = dets[:, topk_inds, ...]
labels = labels[:, topk_inds, ...]

if output_index:
return dets, labels, box_inds
else:
return dets, labels
from mmdeploy.mmcv.ops.nms import multiclass_nms


@FUNCTION_REWRITER.register_rewriter(func_name='models.yolox_pose_head.'
Expand Down Expand Up @@ -187,12 +129,12 @@ def yolox_pose_head__predict_by_feat(
pred_kpts = flatten_decoded_kpts
pred_kpts_score = vis_preds
pred_score, pred_label = scores.max(2, keepdim=True)
nms_result = yolox_pose_head_nms(pred_bbox, pred_score,
nms_result = multiclass_nms(pred_bbox, pred_score,
max_output_boxes_per_class, iou_threshold,
score_threshold)
keep_indices_nms = [nms_result[2]]
score_threshold, 5000, 100, True)

for batch_idx in range(batch_size):
keep_indices_nms = [nms_result[2][batch_idx]]
bbox = pred_bbox[batch_idx][keep_indices_nms]
label = pred_label[batch_idx][keep_indices_nms]
score = pred_score[batch_idx][keep_indices_nms]
Expand Down

0 comments on commit 09443b8

Please sign in to comment.