CodeCamp2023-555 #2469

Merged: 12 commits, Oct 8, 2023
Changes from 7 commits
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
@@ -241,7 +241,7 @@ def postprocessing_results(self,
masks = batch_masks[i]
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
if model_type == 'RTMDet':
if model_type in ['RTMDet', 'CondInst']:
export_postprocess_mask = True
else:
export_postprocess_mask = False
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
@@ -12,3 +12,4 @@
from . import solov2_head # noqa: F401,F403
from . import yolo_head # noqa: F401,F403
from . import yolox_head # noqa: F401,F403
from . import condinst_head # noqa: F401,F403
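
The import above exists only for its side effect: loading `condinst_head` executes the `@FUNCTION_REWRITER.register_rewriter(...)` decorators in the new file, which is why the `# noqa: F401,F403` markers are kept. A minimal sketch of this decorator-registry pattern (illustrative only, not mmdeploy's actual implementation):

# Hypothetical miniature of a decorator-based rewriter registry.
from typing import Callable, Dict

_REWRITERS: Dict[str, Callable] = {}

def register_rewriter(func_path: str):
    """Map a dotted function path to its deploy-time replacement."""
    def decorator(rewriter: Callable) -> Callable:
        _REWRITERS[func_path] = rewriter
        return rewriter
    return decorator

@register_rewriter('mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
def fake_predict_by_feat(*args, **kwargs):
    ...  # a deploy-friendly implementation would go here

# Importing the module that defines the rewriter is enough to populate the
# registry, which is the reason the __init__.py import above exists.
assert 'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat' in _REWRITERS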
203 changes: 203 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
@@ -0,0 +1,203 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmdet.models.utils import aligned_bilinear
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
def condinst_bbox_head__predict_by_feat(
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
param_preds: Optional[List[Tensor]] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True,
):
ctx = FUNCTION_REWRITER.get_context()
deploy_cfg = ctx.cfg

assert len(cls_scores) == len(bbox_preds)
device = bbox_preds[0].device
cfg = self.test_cfg if cfg is None else cfg
batch_size = bbox_preds[0].shape[0]

featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]

all_level_points_strides = self.prior_generator.grid_priors(
featmap_sizes, device=device, with_stride=True)
all_level_points = [i[:, :2] for i in all_level_points_strides]
all_level_strides = [i[:, 2] for i in all_level_points_strides]

flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
self.cls_out_channels)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
for bbox_pred in bbox_preds
]
flatten_score_factors = [
score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
for score_factor in score_factors
]
flatten_param_preds = [
param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params)
for param_pred in param_preds
]
flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid()
flatten_param_preds = torch.cat(flatten_param_preds, dim=1)

points = torch.cat(all_level_points)
strides = torch.cat(all_level_strides)
tl_x = points[..., 0] - flatten_bbox_preds[..., 0]
tl_y = points[..., 1] - flatten_bbox_preds[..., 1]
br_x = points[..., 0] + flatten_bbox_preds[..., 2]
br_y = points[..., 1] + flatten_bbox_preds[..., 3]

bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
scores = flatten_cls_scores
score_factors = flatten_score_factors
param_preds = flatten_param_preds
scores = scores * score_factors

# get post processing config
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

dets, labels, inds = multiclass_nms(
bboxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
pre_top_k=pre_top_k,
keep_top_k=keep_top_k,
output_index=True,
)

batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
points = points.unsqueeze(0).repeat(batch_size, 1, 1)
strides = strides.unsqueeze(0).repeat(batch_size, 1)
param_preds = param_preds[batch_inds, inds, :]
points = points[batch_inds, inds, :]
strides = strides[batch_inds, inds]
results = dict(
dets=dets,
labels=labels,
param_preds=param_preds,
points=points,
strides=strides)
return results


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.forward')
def condinst_mask_head__forward(self, x: tuple,
positive_infos: Dict[str, torch.Tensor]):
mask_feats = self.mask_feature_head(x)

param_preds = positive_infos['param_preds']
points = positive_infos['points']
strides = positive_infos['strides']

batch_size = points.shape[0]
num_insts = points.shape[1]
hw = mask_feats.size()[-2:]
mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1)

points = points.reshape(-1, 1, 2).unsqueeze(0)
locations = self.prior_generator.single_level_grid_priors(
hw, level_idx=0, device=mask_feats.device)
locations = locations.unsqueeze(0).repeat(batch_size, 1,
1).reshape(batch_size, 1, -1, 2)
centers = points.reshape(batch_size, -1, 1, 2)
rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)

# TODO: change following code to support batch inference

weights, biases = _parse_dynamic_params(self, param_preds)
mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
mask_preds = aligned_bilinear(mask_preds,
int(self.mask_feat_stride / self.mask_out_stride))
return (mask_preds, )


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
def condinst_mask_head__predict_by_feat(self,
mask_preds: Tensor,
results_list: Dict[str, torch.Tensor],
batch_img_metas: List[dict],
rescale: bool = True,
**kwargs):
cfg = self.test_cfg

dets = results_list['dets']
labels = results_list['labels']
img_hw = batch_img_metas[0]['img_shape'][:2]

mask_preds = mask_preds.sigmoid()
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
masks = (mask_preds > cfg.mask_thr).float()

return dets, labels, masks


def _parse_dynamic_params(self, params: Tensor):
"""parse the dynamic params for dynamic conv."""
batch_size = params.shape[0]
num_insts = params.shape[1]
params = params.permute(1, 0, 2)
params_splits = list(
torch.split_with_sizes(
params, self.weight_nums + self.bias_nums, dim=2))

weight_splits = params_splits[:self.num_layers]
bias_splits = params_splits[self.num_layers:]

for idx in range(self.num_layers):
if idx < self.num_layers - 1:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, self.in_channels, -1)
else:
weight_splits[idx] = weight_splits[idx].reshape(
batch_size, num_insts, 1, -1)

return weight_splits, bias_splits


def _dynamic_conv_forward(features: Tensor, weights: List[Tensor],
biases: List[Tensor]):
"""dynamic forward, each layer follow a relu."""
n_layers = len(weights)
x = features.flatten(0, 1).flatten(2)

for i, (w, b) in enumerate(zip(weights, biases)):
# replace dynamic conv with bmm
w = w.flatten(0, 1)
b = b.flatten(0, 1).unsqueeze(2)
x = torch.bmm(w, x)
x = x + b
if i < n_layers - 1:
x = x.clamp_(min=0)
return x

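For context on the `# replace dynamic conv with bmm` step in `_dynamic_conv_forward` above: a 1x1 convolution over an HxW feature map is the same linear map as a batched matrix multiply over the flattened pixels, which is what makes the bmm form export-friendly. A small self-contained check of that equivalence, using synthetic shapes and independent of the CondInst code:

import torch
import torch.nn.functional as F

n, c_in, c_out, h, w = 2, 8, 4, 16, 16
feats = torch.randn(n, c_in, h, w)
weight = torch.randn(n, c_out, c_in)  # per-instance dynamic weights
bias = torch.randn(n, c_out)

# Reference: grouped 1x1 convolution with the dynamic weights.
conv = F.conv2d(
    feats.reshape(1, n * c_in, h, w),
    weight.reshape(n * c_out, c_in, 1, 1),
    bias.reshape(n * c_out),
    groups=n).reshape(n, c_out, h, w)

# Deploy-friendly form: flatten the spatial dims and use a batched matmul.
bmm = torch.bmm(weight, feats.flatten(2)) + bias.unsqueeze(2)
bmm = bmm.reshape(n, c_out, h, w)

assert torch.allclose(conv, bmm, atol=1e-5)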
mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py
@@ -12,15 +12,30 @@
'instance_segmentor_forward',
inputs=['input'],
outputs=['dets', 'labels', 'masks'])
def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
def __forward_impl_instance_seg(self,
batch_inputs,
data_samples,
rescale=True,
**kwargs):
"""Rewrite and adding mark for `forward`.

Encapsulate this function for rewriting `forward` of BaseDetector.
1. Add mark for BaseDetector.
2. Support both dynamic and static export to onnx.
"""
x = self.extract_feat(batch_inputs)
mask_outs = self.mask_head.predict(x, data_samples, rescale=False)
if self.with_bbox:
# the bbox branch does not need to be scaled to the original
# image scale, because the mask branch will scale both bbox
# and mask at the same time.
bbox_rescale = rescale if not self.with_mask else False
results_list = self.bbox_head.predict(
x, data_samples, rescale=bbox_rescale)
else:
results_list = None

mask_outs = self.mask_head.predict(
x, data_samples, rescale=rescale, results_list=results_list)
return mask_outs


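The rescale handling added above can be read as: when a mask head is present, boxes are kept at the network input scale so that the mask branch can later rescale boxes and masks together. A tiny illustration of that decision (names mirror the rewritten forward, otherwise standalone):

def choose_bbox_rescale(rescale: bool, with_mask: bool) -> bool:
    # Defer rescaling to the mask branch whenever a mask head exists, so
    # boxes and masks are mapped back to the original image together.
    return rescale if not with_mask else False

assert choose_bbox_rescale(rescale=True, with_mask=True) is False
assert choose_bbox_rescale(rescale=True, with_mask=False) is True
assert choose_bbox_rescale(rescale=False, with_mask=False) is False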
2 changes: 2 additions & 0 deletions mmdeploy/pytorch/functions/repeat.py
@@ -19,6 +19,8 @@

origin_func = ctx.origin_func
if input.dim() == 1 and len(size) == 1:
if isinstance(*size, tuple):
return origin_func(input.unsqueeze(0), *([1] + list(*size))).squeeze(0)
return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
else:
return origin_func(input, *size)
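
The new branch above handles the case where the single repeat argument arrives wrapped in a tuple; in both branches a 1-D repeat is routed through a 2-D unsqueeze/repeat/squeeze, presumably because some export backends handle the 2-D form more reliably. The underlying equivalence, checked in plain PyTorch (illustrative, outside the rewriter):

import torch

x = torch.arange(3)

# Direct 1-D repeat.
direct = x.repeat(4)

# The rewritten path: lift to 2-D, repeat only along the original axis,
# then drop the helper dimension again.
lifted = x.unsqueeze(0).repeat(1, 4).squeeze(0)

assert torch.equal(direct, lifted)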
10 changes: 10 additions & 0 deletions tests/regression/mmdet.yml
@@ -446,3 +446,13 @@ models:
pipelines:
- *pipeline_ort_dynamic_fp32
- *pipeline_trt_dynamic_fp32

- name: CondInst
metafile: /configs/condinst/metafile.yml
model_configs:
- configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py
pipelines:
- deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
backend_test: *default_backend_test
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
backend_test: *default_backend_test
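
With this regression entry in place, a CondInst checkpoint could be converted through mmdeploy's standard `tools/deploy.py` entry point. A hedged sketch of such a run (the checkpoint file and test image are placeholders, and it assumes the command is issued from the mmdeploy repository root with the mmdetection config available at the path listed above):

# Hypothetical conversion run for the new CondInst entry; paths marked as
# placeholders do not come from this PR and must be adapted locally.
import subprocess

subprocess.run([
    'python', 'tools/deploy.py',
    'configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py',
    'configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py',
    'condinst_r50_fpn_ms-poly-90k_coco_instance.pth',  # placeholder checkpoint
    'demo_image.jpg',                                  # placeholder test image
    '--work-dir', 'work_dirs/condinst_ort',
], check=True)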