From 81346b5c861ab788439f6ad63eb72ccf52f29e5d Mon Sep 17 00:00:00 2001 From: lrh Date: Mon, 9 Oct 2023 10:31:34 +0800 Subject: [PATCH] add condinst ut & update docs --- docs/en/04-supported-codebases/mmdet.md | 1 + docs/zh_cn/04-supported-codebases/mmdet.md | 2 + .../test_mmdet/test_mmdet_models.py | 218 ++++++++++++++++++ 3 files changed, 221 insertions(+) diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index dba7b25d27..7220cc409b 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -218,6 +218,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter | [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y | | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | +| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N | diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index 37bfe072a1..57083dfef4 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -10,6 +10,7 @@ - [后端模型推理](#后端模型推理) - [SDK 模型推理](#sdk-模型推理) - [模型支持列表](#模型支持列表) + - [注意事项](#注意事项) ______________________________________________________________________ @@ -220,6 +221,7 @@ cv2.imwrite('output_detection.png', img) | [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y | | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | +| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N | diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index ca1c5c1255..232e50ac5e 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -2364,3 +2364,221 @@ def test_solov2_head_predict_by_feat(backend_type): atol=1e-05) else: assert rewrite_outputs is not None + + +def get_condinst_bbox_head(): + """condinst Bbox Head Config.""" + test_cfg = Config( + dict( + mask_thr=0.5, + max_per_img=100, + min_bbox_size=0, + nms=dict(iou_threshold=0.6, type='nms'), + nms_pre=1000, + score_thr=0.05)) + from mmdet.models.dense_heads import CondInstBboxHead + model = CondInstBboxHead( + center_sampling=True, + centerness_on_reg=True, + conv_bias=True, + dcn_on_last_conv=False, + feat_channels=256, + in_channels=256, + loss_bbox=dict(loss_weight=1.0, type='GIoULoss'), + loss_centerness=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True), + loss_cls=dict( + alpha=0.25, + gamma=2.0, + loss_weight=1.0, + type='FocalLoss', + use_sigmoid=True), + norm_on_bbox=True, + num_classes=80, + num_params=169, + stacked_convs=4, + strides=[ + 8, + 16, + 32, + 64, + 128, + ], + test_cfg=test_cfg, + ) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_condinst_bbox_head_predict_by_feat(backend_type): + """Test predict_by_feat rewrite of condinst bbox head.""" + check_backend(backend_type) + condinst_bbox_head = get_condinst_bbox_head() + condinst_bbox_head.cpu().eval() + s = 128 + batch_img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels', 'param_preds', 'points', 'strides'] + deploy_cfg = Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + export_postprocess_mask=False)))) + + seed_everything(1234) + cls_scores = [ + torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i)) + for i in range(5, 0, -1) + ] + seed_everything(5678) + bbox_preds = [ + torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(9101) + score_factors = [ + torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(1121) + param_preds = [ + torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i)) + for i in range(5, 0, -1) + ] + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel( + condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas) + rewrite_inputs = { + 'cls_scores': cls_scores, + 'bbox_preds': bbox_preds, + 'score_factors': score_factors, + 'param_preds': param_preds, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + dets = rewrite_outputs[0] + labels = rewrite_outputs[1] + param_preds = rewrite_outputs[2] + points = rewrite_outputs[3] + strides = rewrite_outputs[4] + assert dets.shape[-1] == 5 + assert labels is not None + assert param_preds.shape[-1] == condinst_bbox_head.num_params + assert points.shape[-1] == 2 + assert strides is not None + else: + assert rewrite_outputs is not None + + +def get_condinst_mask_head(): + """condinst Mask Head Config.""" + test_cfg = Config( + dict( + mask_thr=0.5, + max_per_img=100, + min_bbox_size=0, + nms=dict(iou_threshold=0.6, type='nms'), + nms_pre=1000, + score_thr=0.05)) + from mmdet.models.dense_heads import CondInstMaskHead + model = CondInstMaskHead( + mask_feature_head=dict( + end_level=2, + feat_channels=128, + in_channels=256, + mask_stride=8, + norm_cfg=dict(requires_grad=True, type='BN'), + num_stacked_convs=4, + out_channels=8, + start_level=0), + num_layers=3, + feat_channels=8, + mask_out_stride=4, + size_of_interest=8, + max_masks_to_train=300, + loss_mask=dict( + activate=True, + eps=5e-06, + loss_weight=1.0, + type='DiceLoss', + use_sigmoid=True), + test_cfg=test_cfg, + ) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_condinst_mask_head_forward(backend_type): + """Test predict_by_feat rewrite of condinst mask head.""" + check_backend(backend_type) + + output_names = ['mask_preds'] + deploy_cfg = Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict(type='mmdet', task='ObjectDetection'))) + + class TestCondInstMaskHeadModel(torch.nn.Module): + + def __init__(self, condinst_mask_head): + super(TestCondInstMaskHeadModel, self).__init__() + self.mask_head = condinst_mask_head + + def forward(self, x, param_preds, points, strides): + positive_infos = dict( + param_preds=param_preds, points=points, strides=strides) + return self.mask_head(x, positive_infos) + + mask_head = get_condinst_mask_head() + level = mask_head.mask_feature_head.end_level - \ + mask_head.mask_feature_head.start_level + 1 + + condinst_mask_head = TestCondInstMaskHeadModel(mask_head) + condinst_mask_head.cpu().eval() + + seed_everything(1234) + x = [torch.rand(1, 256, pow(2, i), pow(2, i)) for i in range(level, 0, -1)] + seed_everything(5678) + param_preds = torch.rand(1, 100, 169) + seed_everything(9101) + points = torch.rand(1, 100, 2) + seed_everything(1121) + strides = torch.rand(1, 100) + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel(condinst_mask_head, 'forward') + rewrite_inputs = { + 'x': x, + 'param_preds': param_preds, + 'points': points, + 'strides': strides + } + rewrite_outputs, _ = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + assert rewrite_outputs is not None