Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support htc #2438

Merged
merged 2 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
Expand Down
1 change: 1 addition & 0 deletions docs/zh_cn/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ cv2.imwrite('output_detection.png', img)
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/roi_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from . import bbox_head # noqa: F401,F403
from . import cascade_roi_head # noqa: F401,F403
from . import fcn_mask_head # noqa: F401,F403
from . import htc_roi_head # noqa: F401,F403
from . import single_level_roi_extractor # noqa: F401,F403
from . import standard_roi_head # noqa: F401,F403
6 changes: 3 additions & 3 deletions mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def cascade_roi_head__predict_bbox(self,
batch_img_metas: List[dict],
rpn_results_list: List[Tensor],
rcnn_test_cfg: ConfigType,
rescale: bool = False) -> List[Tensor]:
rescale: bool = False,
**kwargs) -> List[Tensor]:
"""Rewrite `predict_bbox` of `CascadeRoIHead` for default backend.

Args:
Expand Down Expand Up @@ -52,8 +53,7 @@ def cascade_roi_head__predict_bbox(self,
ms_scores = []
max_shape = batch_img_metas[0]['img_shape']
for i in range(self.num_stages):
bbox_results = self._bbox_forward(i, x, rois)

bbox_results = self._bbox_forward(i, x, rois, **kwargs)
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
# Recover the batch dimension
Expand Down
48 changes: 48 additions & 0 deletions mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.roi_heads.htc_roi_head.HybridTaskCascadeRoIHead.predict_mask'
)
def htc_roi_head__predict_mask(self,
x: Tuple[Tensor],
semantic_heat: Tensor,
batch_img_metas: List[dict],
results_list: List[Tensor],
rescale: bool = False) -> List[Tensor]:
dets, det_labels = results_list

Check warning on line 19 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L19

Added line #L19 was not covered by tests

batch_size = dets.size(0)
det_bboxes = dets[..., :4]
batch_index = torch.arange(

Check warning on line 23 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L21-L23

Added lines #L21 - L23 were not covered by tests
det_bboxes.size(0),
device=det_bboxes.device).float().view(-1, 1, 1).expand(
det_bboxes.size(0), det_bboxes.size(1), 1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
mask_rois = mask_rois.view(-1, 5)

Check warning on line 28 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L27-L28

Added lines #L27 - L28 were not covered by tests

mask_results = self._mask_forward(

Check warning on line 30 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L30

Added line #L30 was not covered by tests
stage=-1,
x=x,
rois=mask_rois,
semantic_feat=semantic_heat,
training=False)

mask_preds = mask_results['mask_preds'][0]
num_det = det_bboxes.shape[1]
segm_results = self.mask_head[-1].predict_by_feat(

Check warning on line 39 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L37-L39

Added lines #L37 - L39 were not covered by tests
mask_preds,
results_list,
batch_img_metas,
self.test_cfg,
rescale=rescale)
segm_results = segm_results.reshape(batch_size, num_det,

Check warning on line 45 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L45

Added line #L45 was not covered by tests
segm_results.shape[-2],
segm_results.shape[-1])
return dets, det_labels, segm_results

Check warning on line 48 in mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py#L48

Added line #L48 was not covered by tests
8 changes: 8 additions & 0 deletions tests/regression/mmdet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,11 @@ models:
backend_test: *default_backend_test
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
backend_test: *default_backend_test

- name: HTC
metafile: configs/htc/metafile.yml
model_configs:
- configs/htc/htc_r50_fpn_1x_coco.py
pipelines:
- *pipeline_seg_ort_dynamic_fp32
- *pipeline_seg_trt_dynamic_fp32
Loading