From fbac65e8af2ca2fae07ab0c9c4e6bb1a4ae45afc Mon Sep 17 00:00:00 2001 From: Lrh Date: Sat, 7 Oct 2023 15:46:29 +0800 Subject: [PATCH] add condinst head unit testing --- .../test_mmdet/test_mmdet_models.py | 252 ++++++++++++++++++ 1 file changed, 252 insertions(+) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index ca1c5c1255..e8f19b2535 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -2364,3 +2364,255 @@ def test_solov2_head_predict_by_feat(backend_type): atol=1e-05) else: assert rewrite_outputs is not None + + +def get_condinst_bbox_head(): + """condinst Bbox Head Config.""" + test_cfg = Config( + dict( + mask_thr=0.5, + max_per_img=100, + min_bbox_size=0, + nms=dict(iou_threshold=0.6, type='nms'), + nms_pre=1000, + score_thr=0.05)) + from mmdet.models.dense_heads import CondInstBboxHead + model = CondInstBboxHead( + center_sampling=True, + centerness_on_reg=True, + conv_bias=True, + dcn_on_last_conv=False, + feat_channels=256, + in_channels=256, + loss_bbox=dict(loss_weight=1.0, type='GIoULoss'), + loss_centerness=dict( + loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True), + loss_cls=dict( + alpha=0.25, + gamma=2.0, + loss_weight=1.0, + type='FocalLoss', + use_sigmoid=True), + norm_on_bbox=True, + num_classes=80, + num_params=169, + stacked_convs=4, + strides=[ + 8, + 16, + 32, + 64, + 128, + ], + test_cfg=test_cfg, + ) + + model.requires_grad_(False) + return model + + +@pytest.mark.skipif( + reason='Only support GPU test', condition=not torch.cuda.is_available()) +@pytest.mark.parametrize('backend_type', + [Backend.ONNXRUNTIME, Backend.TENSORRT]) +def test_condinst_bbox_head_predict_by_feat(backend_type): + """Test predict_by_feat rewrite of condinst bbox head.""" + check_backend(backend_type) + condinst_bbox_head = get_condinst_bbox_head() + condinst_bbox_head.cpu().eval() + s = 128 + batch_img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels', 'param_preds', 'points', 'strides'] + deploy_cfg = Config( + dict( + backend_config=dict( + type=backend_type.value, + common_config=dict(max_workspace_size=1 << 32), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 320, 320], + opt_shape=[1, 3, 800, 1344], + max_shape=[1, 3, 1344, 1344]))) + ]), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + export_postprocess_mask=False)))) + + seed_everything(1234) + cls_scores = [ + torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i)) + for i in range(5, 0, -1) + ] + seed_everything(5678) + bbox_preds = [ + torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(9101) + score_factors = [ + torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(1121) + param_preds = [ + torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i)) + for i in range(5, 0, -1) + ] + + # to get outputs of onnx/tensorrt model after rewrite + wrapped_model = WrapModel( + condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas) + rewrite_inputs = { + 'cls_scores': cls_scores, + 'bbox_preds': bbox_preds, + 'score_factors': score_factors, + 'param_preds': param_preds, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + dets = rewrite_outputs[0] + labels = rewrite_outputs[1] + param_preds = rewrite_outputs[2] + points = rewrite_outputs[3] + strides = rewrite_outputs[4] + assert dets.shape[-1] == 5 + assert labels is not None + assert param_preds.shape[-1] == condinst_bbox_head.num_params + assert points.shape[-1] == 2 + assert strides is not None + else: + assert rewrite_outputs is not None + + +def get_condinst_mask_head(): + """condinst Mask Head Config.""" + test_cfg = Config( + dict( + mask_thr=0.5, + max_per_img=100, + min_bbox_size=0, + nms=dict(iou_threshold=0.6, type='nms'), + nms_pre=1000, + score_thr=0.05)) + from mmdet.models.dense_heads import CondInstMaskHead + model = CondInstMaskHead( + mask_feature_head=dict( + end_level=2, + feat_channels=128, + in_channels=256, + mask_stride=8, + norm_cfg=dict(requires_grad=True, type='BN'), + num_stacked_convs=4, + out_channels=8, + start_level=0), + num_layers=3, + feat_channels=8, + mask_out_stride=4, + size_of_interest=8, + max_masks_to_train=300, + loss_mask=dict( + activate=True, + eps=5e-06, + loss_weight=1.0, + type='DiceLoss', + use_sigmoid=True), + test_cfg=test_cfg, + ) + + model.requires_grad_(False) + return model + + +@pytest.mark.skipif( + reason='Only support GPU test', condition=not torch.cuda.is_available()) +@pytest.mark.parametrize('backend_type', + [Backend.ONNXRUNTIME, Backend.TENSORRT]) +def test_condinst_mask_head_predict_by_feat(backend_type): + """Test predict_by_feat rewrite of condinst mask head.""" + check_backend(backend_type) + s = 128 + batch_img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + + output_names = ['dets', 'labels', 'masks'] + deploy_cfg = Config( + dict( + backend_config=dict( + type=backend_type.value, + common_config=dict(max_workspace_size=1 << 32), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 320, 320], + opt_shape=[1, 3, 800, 1344], + max_shape=[1, 3, 1344, 1344]))) + ]), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection'))) + + class TestCondInstMaskHeadModel(torch.nn.Module): + def __init__(self, condinst_mask_head): + super(TestCondInstMaskHeadModel, self).__init__() + self.mask_head = condinst_mask_head + + def predict_by_feat(self, mask_preds, det, label, batch_img_metas): + results = dict(dets=det, labels=label) + return self.mask_head.predict_by_feat(mask_preds, results, batch_img_metas) + + head = get_condinst_mask_head() + condinst_mask_head = TestCondInstMaskHeadModel(head) + condinst_mask_head.cpu().eval() + + seed_everything(1234) + mask_preds = torch.rand(1, 100, 200, 200) + seed_everything(5678) + dets = torch.rand(1, 100, 5) + labels = torch.rand(1, 100) + + # to get outputs of onnx/tensorrt model after rewrite + wrapped_model = WrapModel( + condinst_mask_head, 'predict_by_feat', batch_img_metas=batch_img_metas) + rewrite_inputs = { + 'mask_preds': mask_preds, + 'det': dets, + 'label': labels + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + dets = rewrite_outputs[0] + labels = rewrite_outputs[1] + masks = rewrite_outputs[2] + assert dets.shape[-1] == 5 + assert labels is not None + assert masks is not None + else: + assert rewrite_outputs is not None