Commit

add condinst head unit testing
Boomerl committed Oct 7, 2023
1 parent d580a59 commit fbac65e
Showing 1 changed file with 252 additions and 0 deletions: tests/test_codebase/test_mmdet/test_mmdet_models.py
@@ -2364,3 +2364,255 @@ def test_solov2_head_predict_by_feat(backend_type):
atol=1e-05)
else:
assert rewrite_outputs is not None


def get_condinst_bbox_head():
"""condinst Bbox Head Config."""
test_cfg = Config(
dict(
mask_thr=0.5,
max_per_img=100,
min_bbox_size=0,
nms=dict(iou_threshold=0.6, type='nms'),
nms_pre=1000,
score_thr=0.05))
from mmdet.models.dense_heads import CondInstBboxHead
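    # Note: num_params=169 below is the per-instance dynamic-filter count of
    # the CondInst controller (3-layer, 8-channel dynamic mask FCN:
    # 88 + 72 + 9 parameters).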
model = CondInstBboxHead(
center_sampling=True,
centerness_on_reg=True,
conv_bias=True,
dcn_on_last_conv=False,
feat_channels=256,
in_channels=256,
loss_bbox=dict(loss_weight=1.0, type='GIoULoss'),
loss_centerness=dict(
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
loss_cls=dict(
alpha=0.25,
gamma=2.0,
loss_weight=1.0,
type='FocalLoss',
use_sigmoid=True),
norm_on_bbox=True,
num_classes=80,
num_params=169,
stacked_convs=4,
strides=[
8,
16,
32,
64,
128,
],
test_cfg=test_cfg,
)

model.requires_grad_(False)
return model


@pytest.mark.skipif(
reason='Only support GPU test', condition=not torch.cuda.is_available())
@pytest.mark.parametrize('backend_type',
[Backend.ONNXRUNTIME, Backend.TENSORRT])
def test_condinst_bbox_head_predict_by_feat(backend_type):
"""Test predict_by_feat rewrite of condinst bbox head."""
check_backend(backend_type)
condinst_bbox_head = get_condinst_bbox_head()
condinst_bbox_head.cpu().eval()
s = 128
batch_img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]

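    # The rewritten predict_by_feat is expected to export five tensors: the
    # detections and labels plus the per-instance dynamic-conv parameters,
    # point coordinates and strides that the mask head consumes afterwards.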
output_names = ['dets', 'labels', 'param_preds', 'points', 'strides']
deploy_cfg = Config(
dict(
backend_config=dict(
type=backend_type.value,
common_config=dict(max_workspace_size=1 << 32),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 800, 1344],
max_shape=[1, 3, 1344, 1344])))
]),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
export_postprocess_mask=False))))

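    # Dummy multi-level inputs standing in for five pyramid levels, with
    # feature maps of 32x32 down to 2x2 (pow(2, i) for i = 5..1).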
seed_everything(1234)
cls_scores = [
torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bbox_preds = [
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(9101)
score_factors = [
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
seed_everything(1121)
param_preds = [
torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]

# to get outputs of onnx/tensorrt model after rewrite
wrapped_model = WrapModel(
condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
rewrite_inputs = {
'cls_scores': cls_scores,
'bbox_preds': bbox_preds,
'score_factors': score_factors,
'param_preds': param_preds,
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

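    # With a real backend run, outputs follow the order of output_names above:
    # dets are (batch, num_dets, 5) boxes with a trailing score and points are
    # (x, y) locations; otherwise only a non-None placeholder is checked.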
if is_backend_output:
dets = rewrite_outputs[0]
labels = rewrite_outputs[1]
param_preds = rewrite_outputs[2]
points = rewrite_outputs[3]
strides = rewrite_outputs[4]
assert dets.shape[-1] == 5
assert labels is not None
assert param_preds.shape[-1] == condinst_bbox_head.num_params
assert points.shape[-1] == 2
assert strides is not None
else:
assert rewrite_outputs is not None


def get_condinst_mask_head():
"""condinst Mask Head Config."""
test_cfg = Config(
dict(
mask_thr=0.5,
max_per_img=100,
min_bbox_size=0,
nms=dict(iou_threshold=0.6, type='nms'),
nms_pre=1000,
score_thr=0.05))
from mmdet.models.dense_heads import CondInstMaskHead
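    # The 8-channel mask features and 3 dynamic conv layers here are the
    # counterpart of num_params=169 configured on the bbox head above.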
model = CondInstMaskHead(
mask_feature_head=dict(
end_level=2,
feat_channels=128,
in_channels=256,
mask_stride=8,
norm_cfg=dict(requires_grad=True, type='BN'),
num_stacked_convs=4,
out_channels=8,
start_level=0),
num_layers=3,
feat_channels=8,
mask_out_stride=4,
size_of_interest=8,
max_masks_to_train=300,
loss_mask=dict(
activate=True,
eps=5e-06,
loss_weight=1.0,
type='DiceLoss',
use_sigmoid=True),
test_cfg=test_cfg,
)

model.requires_grad_(False)
return model


@pytest.mark.skipif(
reason='Only support GPU test', condition=not torch.cuda.is_available())
@pytest.mark.parametrize('backend_type',
[Backend.ONNXRUNTIME, Backend.TENSORRT])
def test_condinst_mask_head_predict_by_feat(backend_type):
"""Test predict_by_feat rewrite of condinst mask head."""
check_backend(backend_type)
s = 128
batch_img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]

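    # Only dets, labels and masks are exported for the mask head; NMS has
    # presumably already been applied by the bbox head, so no post_processing
    # block is configured here.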
output_names = ['dets', 'labels', 'masks']
deploy_cfg = Config(
dict(
backend_config=dict(
type=backend_type.value,
common_config=dict(max_workspace_size=1 << 32),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 800, 1344],
max_shape=[1, 3, 1344, 1344])))
]),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection')))

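    # Thin adapter so WrapModel can trace predict_by_feat: the underlying head
    # expects the bbox-head results bundled together, so the dets/labels
    # tensors are packed into a dict before the call.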
class TestCondInstMaskHeadModel(torch.nn.Module):
def __init__(self, condinst_mask_head):
super(TestCondInstMaskHeadModel, self).__init__()
self.mask_head = condinst_mask_head

def predict_by_feat(self, mask_preds, det, label, batch_img_metas):
results = dict(dets=det, labels=label)
            return self.mask_head.predict_by_feat(
                mask_preds, results, batch_img_metas)

head = get_condinst_mask_head()
condinst_mask_head = TestCondInstMaskHeadModel(head)
condinst_mask_head.cpu().eval()

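    # Dummy inputs: 100 candidate instances with 200x200 mask predictions,
    # (x1, y1, x2, y2, score) boxes and class labels (random floats suffice
    # for tracing here).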
seed_everything(1234)
mask_preds = torch.rand(1, 100, 200, 200)
seed_everything(5678)
dets = torch.rand(1, 100, 5)
labels = torch.rand(1, 100)

# to get outputs of onnx/tensorrt model after rewrite
wrapped_model = WrapModel(
condinst_mask_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
rewrite_inputs = {
'mask_preds': mask_preds,
'det': dets,
'label': labels
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)

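    # Backend outputs should mirror output_names: dets as (batch, num_dets, 5),
    # one label and one mask per detection.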
if is_backend_output:
dets = rewrite_outputs[0]
labels = rewrite_outputs[1]
masks = rewrite_outputs[2]
assert dets.shape[-1] == 5
assert labels is not None
assert masks is not None
else:
assert rewrite_outputs is not None
