Skip to content

Commit

Permalink
support htc (#2438)
Browse files Browse the repository at this point in the history
* support htc

* update mmdet.yml
  • Loading branch information
irexyc authored Oct 17, 2023
1 parent c4dc10d commit 1090fb6
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ cv2.imwrite('output_detection.png', img)
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/roi_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from . import bbox_head # noqa: F401,F403
from . import cascade_roi_head # noqa: F401,F403
from . import fcn_mask_head # noqa: F401,F403
from . import htc_roi_head # noqa: F401,F403
from . import single_level_roi_extractor # noqa: F401,F403
from . import standard_roi_head # noqa: F401,F403
6 changes: 3 additions & 3 deletions mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def cascade_roi_head__predict_bbox(self,
batch_img_metas: List[dict],
rpn_results_list: List[Tensor],
rcnn_test_cfg: ConfigType,
rescale: bool = False) -> List[Tensor]:
rescale: bool = False,
**kwargs) -> List[Tensor]:
"""Rewrite `predict_bbox` of `CascadeRoIHead` for default backend.
Args:
Expand Down Expand Up @@ -52,8 +53,7 @@ def cascade_roi_head__predict_bbox(self,
ms_scores = []
max_shape = batch_img_metas[0]['img_shape']
for i in range(self.num_stages):
bbox_results = self._bbox_forward(i, x, rois)

bbox_results = self._bbox_forward(i, x, rois, **kwargs)
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
# Recover the batch dimension
Expand Down
48 changes: 48 additions & 0 deletions mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.htc_roi_head.HybridTaskCascadeRoIHead.predict_mask'
)
def htc_roi_head__predict_mask(self,
x: Tuple[Tensor],
semantic_heat: Tensor,
batch_img_metas: List[dict],
results_list: List[Tensor],
rescale: bool = False) -> List[Tensor]:
dets, det_labels = results_list

batch_size = dets.size(0)
det_bboxes = dets[..., :4]
batch_index = torch.arange(
det_bboxes.size(0),
device=det_bboxes.device).float().view(-1, 1, 1).expand(
det_bboxes.size(0), det_bboxes.size(1), 1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
mask_rois = mask_rois.view(-1, 5)

mask_results = self._mask_forward(
stage=-1,
x=x,
rois=mask_rois,
semantic_feat=semantic_heat,
training=False)

mask_preds = mask_results['mask_preds'][0]
num_det = det_bboxes.shape[1]
segm_results = self.mask_head[-1].predict_by_feat(
mask_preds,
results_list,
batch_img_metas,
self.test_cfg,
rescale=rescale)
segm_results = segm_results.reshape(batch_size, num_det,
segm_results.shape[-2],
segm_results.shape[-1])
return dets, det_labels, segm_results
8 changes: 8 additions & 0 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,11 @@ models:
backend_test: *default_backend_test
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
backend_test: *default_backend_test

- name: HTC
metafile: configs/htc/metafile.yml
model_configs:
- configs/htc/htc_r50_fpn_1x_coco.py
pipelines:
- *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_trt_dynamic_fp32

0 comments on commit 1090fb6

Please sign in to comment.