Skip to content

Commit

Permalink
add condinst ut & update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Boomerl committed Oct 9, 2023
1 parent 4c376d9 commit 81346b5
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/en/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
Expand Down
2 changes: 2 additions & 0 deletions docs/zh_cn/04-supported-codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- [后端模型推理](#后端模型推理)
- [SDK 模型推理](#sdk-模型推理)
- [模型支持列表](#模型支持列表)
- [注意事项](#注意事项)

______________________________________________________________________

Expand Down Expand Up @@ -220,6 +221,7 @@ cv2.imwrite('output_detection.png', img)
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |
| [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y |
| [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N |
| [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N |
| [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N |
| [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N |
Expand Down
218 changes: 218 additions & 0 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,3 +2364,221 @@ def test_solov2_head_predict_by_feat(backend_type):
atol=1e-05)
else:
assert rewrite_outputs is not None


def get_condinst_bbox_head():
"""condinst Bbox Head Config."""
test_cfg = Config(
dict(
mask_thr=0.5,
max_per_img=100,
min_bbox_size=0,
nms=dict(iou_threshold=0.6, type='nms'),
nms_pre=1000,
score_thr=0.05))
from mmdet.models.dense_heads import CondInstBboxHead
model = CondInstBboxHead(
center_sampling=True,
centerness_on_reg=True,
conv_bias=True,
dcn_on_last_conv=False,
feat_channels=256,
in_channels=256,
loss_bbox=dict(loss_weight=1.0, type='GIoULoss'),
loss_centerness=dict(
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
loss_cls=dict(
alpha=0.25,
gamma=2.0,
loss_weight=1.0,
type='FocalLoss',
use_sigmoid=True),
norm_on_bbox=True,
num_classes=80,
num_params=169,
stacked_convs=4,
strides=[
8,
16,
32,
64,
128,
],
test_cfg=test_cfg,
)

model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_condinst_bbox_head_predict_by_feat(backend_type):
"""Test predict_by_feat rewrite of condinst bbox head."""
check_backend(backend_type)
condinst_bbox_head = get_condinst_bbox_head()
condinst_bbox_head.cpu().eval()
s = 128
batch_img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]

output_names = ['dets', 'labels', 'param_preds', 'points', 'strides']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
export_postprocess_mask=False))))

seed_everything(1234)
cls_scores = [
torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bbox_preds = [
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(9101)
score_factors = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(1121)
param_preds = [
torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]

# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(
condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
rewrite_inputs = {
'cls_scores': cls_scores,
'bbox_preds': bbox_preds,
'score_factors': score_factors,
'param_preds': param_preds,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

if is_backend_output:
dets = rewrite_outputs[0]
labels = rewrite_outputs[1]
param_preds = rewrite_outputs[2]
points = rewrite_outputs[3]
strides = rewrite_outputs[4]
assert dets.shape[-1] == 5
assert labels is not None
assert param_preds.shape[-1] == condinst_bbox_head.num_params
assert points.shape[-1] == 2
assert strides is not None
else:
assert rewrite_outputs is not None


def get_condinst_mask_head():
"""condinst Mask Head Config."""
test_cfg = Config(
dict(
mask_thr=0.5,
max_per_img=100,
min_bbox_size=0,
nms=dict(iou_threshold=0.6, type='nms'),
nms_pre=1000,
score_thr=0.05))
from mmdet.models.dense_heads import CondInstMaskHead
model = CondInstMaskHead(
mask_feature_head=dict(
end_level=2,
feat_channels=128,
in_channels=256,
mask_stride=8,
norm_cfg=dict(requires_grad=True, type='BN'),
num_stacked_convs=4,
out_channels=8,
start_level=0),
num_layers=3,
feat_channels=8,
mask_out_stride=4,
size_of_interest=8,
max_masks_to_train=300,
loss_mask=dict(
activate=True,
eps=5e-06,
loss_weight=1.0,
type='DiceLoss',
use_sigmoid=True),
test_cfg=test_cfg,
)

model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_condinst_mask_head_forward(backend_type):
"""Test predict_by_feat rewrite of condinst mask head."""
check_backend(backend_type)

output_names = ['mask_preds']
deploy_cfg = Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(type='mmdet', task='ObjectDetection')))

class TestCondInstMaskHeadModel(torch.nn.Module):

def __init__(self, condinst_mask_head):
super(TestCondInstMaskHeadModel, self).__init__()
self.mask_head = condinst_mask_head

def forward(self, x, param_preds, points, strides):
positive_infos = dict(
param_preds=param_preds, points=points, strides=strides)
return self.mask_head(x, positive_infos)

mask_head = get_condinst_mask_head()
level = mask_head.mask_feature_head.end_level - \
mask_head.mask_feature_head.start_level + 1

condinst_mask_head = TestCondInstMaskHeadModel(mask_head)
condinst_mask_head.cpu().eval()

seed_everything(1234)
x = [torch.rand(1, 256, pow(2, i), pow(2, i)) for i in range(level, 0, -1)]
seed_everything(5678)
param_preds = torch.rand(1, 100, 169)
seed_everything(9101)
points = torch.rand(1, 100, 2)
seed_everything(1121)
strides = torch.rand(1, 100)

# to get outputs of onnx model after rewrite
wrapped_model = WrapModel(condinst_mask_head, 'forward')
rewrite_inputs = {
'x': x,
'param_preds': param_preds,
'points': points,
'strides': strides
}
rewrite_outputs, _ = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

assert rewrite_outputs is not None

0 comments on commit 81346b5

Please sign in to comment.