Skip to content

Commit

Permalink
support batch inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Boomerl committed Sep 28, 2023
1 parent cdcb04e commit d580a59
Showing 1 changed file with 13 additions and 34 deletions.
47 changes: 13 additions & 34 deletions mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch
from mmdet.models.utils import aligned_bilinear
from mmdet.utils import InstanceList
from mmengine.config import ConfigDict
from torch import Tensor

Expand Down Expand Up @@ -117,9 +116,6 @@ def condinst_mask_head__forward(self, x: tuple,
param_preds = positive_infos['param_preds']
points = positive_infos['points']
strides = positive_infos['strides']

Check warning on line 118 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L116-L118

Added lines #L116 - L118 were not covered by tests
param_preds = torch.stack(param_preds, dim=0)
points = torch.stack(points, dim=0)
strides = torch.stack(strides, dim=0)

batch_size = points.shape[0]
num_insts = points.shape[1]
Expand All @@ -135,52 +131,35 @@ def condinst_mask_head__forward(self, x: tuple,
rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=1)
mask_head_inputs = mask_head_inputs.reshape(batch_size, -1, hw[0], hw[1])
mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)

Check warning on line 134 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L130-L134

Added lines #L130 - L134 were not covered by tests
# TODO: change following code to support batch inference

weights, biases = _parse_dynamic_params(self, param_preds)
mask_preds = _dynamic_conv_forward(mask_feats, weights, biases)
mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
mask_preds = [
aligned_bilinear(
mask_preds[i].unsqueeze(0),
int(self.mask_feat_stride / self.mask_out_stride),
).squeeze(0) for i in range(batch_size)
]

mask_preds = aligned_bilinear(mask_preds,

Check warning on line 140 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L137-L140

Added lines #L137 - L140 were not covered by tests
int(self.mask_feat_stride / self.mask_out_stride))
return (mask_preds, )

Check warning on line 142 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L142

Added line #L142 was not covered by tests


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
def condinst_mask_head__predict_by_feat(self,
mask_preds: List[Tensor],
results_list: InstanceList,
mask_preds: Tensor,
results_list: Dict[str, torch.Tensor],
batch_img_metas: List[dict],
rescale: bool = True,
**kwargs):
assert len(mask_preds) == len(results_list) == len(batch_img_metas)
cfg = self.test_cfg

Check warning on line 153 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L153

Added line #L153 was not covered by tests

dets = [results.dets.unsqueeze(0) for results in results_list]
labels = [results.labels.unsqueeze(0) for results in results_list]
img_hw = [img_meta['img_shape'][:2] for img_meta in batch_img_metas]

mask_preds = [
mask_preds[i].sigmoid().unsqueeze(0) for i in range(len(mask_preds))
]
mask_preds = [
aligned_bilinear(mask_preds[i], self.mask_out_stride)
for i in range(len(mask_preds))
]
mask_preds = [
mask_preds[i][:, :, :img_hw[i][0], :img_hw[i][1]]
for i in range(len(mask_preds))
]
dets = results_list['dets']
labels = results_list['labels']
img_hw = batch_img_metas[0]['img_shape'][:2]

Check warning on line 157 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L155-L157

Added lines #L155 - L157 were not covered by tests

masks = [mask_preds[i] > cfg.mask_thr for i in range(len(mask_preds))]
masks = [masks[i].float() for i in range(len(masks))]
mask_preds = mask_preds.sigmoid()
mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
masks = (mask_preds > cfg.mask_thr).float()

Check warning on line 162 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L159-L162

Added lines #L159 - L162 were not covered by tests

return dets, labels, masks

Check warning on line 164 in mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py#L164

Added line #L164 was not covered by tests

Expand Down

0 comments on commit d580a59

Please sign in to comment.