From 1090fb6f7b1be3b437208dcbdbc6986376b8bb47 Mon Sep 17 00:00:00 2001 From: Chen Xin Date: Tue, 17 Oct 2023 08:21:33 +0800 Subject: [PATCH] support htc (#2438) * support htc * update mmdet.yml --- docs/en/04-supported-codebases/mmdet.md | 1 + docs/zh_cn/04-supported-codebases/mmdet.md | 1 + .../mmdet/models/roi_heads/__init__.py | 1 + .../models/roi_heads/cascade_roi_head.py | 6 +-- .../mmdet/models/roi_heads/htc_roi_head.py | 48 +++++++++++++++++++ tests/regression/mmdet.yml | 8 ++++ 6 files changed, 62 insertions(+), 3 deletions(-) create mode 100644 mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index 7220cc409b..16bbacb299 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -214,6 +214,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter | [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y | | [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y | | [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y | +| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y | | [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y | | [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y | | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index 57083dfef4..c131f76698 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -217,6 +217,7 @@ cv2.imwrite('output_detection.png', img) | [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y | | [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y | | [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y | +| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y | | [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y | | [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y | | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py b/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py index f12a70dc6c..de0b68dfba 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py @@ -2,5 +2,6 @@ from . import bbox_head # noqa: F401,F403 from . import cascade_roi_head # noqa: F401,F403 from . import fcn_mask_head # noqa: F401,F403 +from . import htc_roi_head # noqa: F401,F403 from . import single_level_roi_extractor # noqa: F401,F403 from . import standard_roi_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py index d556ae4ce2..3bed888890 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py @@ -16,7 +16,8 @@ def cascade_roi_head__predict_bbox(self, batch_img_metas: List[dict], rpn_results_list: List[Tensor], rcnn_test_cfg: ConfigType, - rescale: bool = False) -> List[Tensor]: + rescale: bool = False, + **kwargs) -> List[Tensor]: """Rewrite `predict_bbox` of `CascadeRoIHead` for default backend. Args: @@ -52,8 +53,7 @@ def cascade_roi_head__predict_bbox(self, ms_scores = [] max_shape = batch_img_metas[0]['img_shape'] for i in range(self.num_stages): - bbox_results = self._bbox_forward(i, x, rois) - + bbox_results = self._bbox_forward(i, x, rois, **kwargs) cls_score = bbox_results['cls_score'] bbox_pred = bbox_results['bbox_pred'] # Recover the batch dimension diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py new file mode 100644 index 0000000000..def9188441 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/roi_heads/htc_roi_head.py @@ -0,0 +1,48 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + 'mmdet.models.roi_heads.htc_roi_head.HybridTaskCascadeRoIHead.predict_mask' +) +def htc_roi_head__predict_mask(self, + x: Tuple[Tensor], + semantic_heat: Tensor, + batch_img_metas: List[dict], + results_list: List[Tensor], + rescale: bool = False) -> List[Tensor]: + dets, det_labels = results_list + + batch_size = dets.size(0) + det_bboxes = dets[..., :4] + batch_index = torch.arange( + det_bboxes.size(0), + device=det_bboxes.device).float().view(-1, 1, 1).expand( + det_bboxes.size(0), det_bboxes.size(1), 1) + mask_rois = torch.cat([batch_index, det_bboxes], dim=-1) + mask_rois = mask_rois.view(-1, 5) + + mask_results = self._mask_forward( + stage=-1, + x=x, + rois=mask_rois, + semantic_feat=semantic_heat, + training=False) + + mask_preds = mask_results['mask_preds'][0] + num_det = det_bboxes.shape[1] + segm_results = self.mask_head[-1].predict_by_feat( + mask_preds, + results_list, + batch_img_metas, + self.test_cfg, + rescale=rescale) + segm_results = segm_results.reshape(batch_size, num_det, + segm_results.shape[-2], + segm_results.shape[-1]) + return dets, det_labels, segm_results diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 679715afa2..c7dc1d9a73 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -456,3 +456,11 @@ models: backend_test: *default_backend_test - deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py backend_test: *default_backend_test + + - name: HTC + metafile: configs/htc/metafile.yml + model_configs: + - configs/htc/htc_r50_fpn_1x_coco.py + pipelines: + - *pipeline_seg_ort_dynamic_fp32 + - *pipeline_seg_trt_dynamic_fp32