Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Aug 15, 2023
1 parent 6b93792 commit bef8e50
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def forward(self,

bboxes = dets[:, :4]
scores = dets[:, 4]
print(img_metas[i])

# perform rescale
if rescale and 'scale_factor' in img_metas[i]:
scale_factor = img_metas[i]['scale_factor']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,7 @@ def _nms_with_mask_static(self,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True)
# remove padded
dets = dets[:, :-1, :]
labels = labels[:, :-1]

batch_size = bboxes.shape[0]
batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
kernels = kernels[batch_inds, inds, :]
Expand Down
5 changes: 2 additions & 3 deletions mmdeploy/mmcv/ops/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _select_nms_index(scores: torch.Tensor,
pre_inds = pre_inds.unsqueeze(0).repeat(batch_size, 1)
pre_inds = pre_inds.where((batch_inds == batch_template.unsqueeze(1)),
pre_inds.new_zeros(1))
pre_inds = torch.cat((pre_inds, pre_inds.new_zeros((N, 1))), 1)
pre_inds = torch.cat((pre_inds, -pre_inds.new_ones((N, 1))), 1)
# sort
is_use_topk = keep_top_k > 0 and \
(torch.onnx.is_in_onnx_export() or keep_top_k < batched_dets.shape[1])
Expand All @@ -258,7 +258,6 @@ def _select_nms_index(scores: torch.Tensor,
if output_index:
if pre_inds is not None:
topk_inds = pre_inds[topk_batch_inds, topk_inds, ...]
topk_inds = topk_inds[:, :-1]
return batched_dets, batched_labels, topk_inds
# slice and recover the tensor
return batched_dets, batched_labels
Expand All @@ -285,7 +284,7 @@ def _multiclass_nms(boxes: Tensor,
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
score_threshold = torch.tensor([score_threshold], dtype=torch.float32)
batch_size = scores.shape[0]

topk_inds = None
if pre_top_k > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_top_k)
Expand Down

0 comments on commit bef8e50

Please sign in to comment.