Skip to content

Commit

Permalink
support sparseinst batch inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Boomerl committed Nov 10, 2023
1 parent e31f5a6 commit 2a481bf
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def postprocessing_results(self,
masks = batch_masks[i]
img_h, img_w = img_metas[i]['img_shape'][:2]
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
if model_type in ['RTMDet', 'CondInst']:
if model_type in ['RTMDet', 'CondInst', 'SparseInst']:
export_postprocess_mask = True
else:
export_postprocess_mask = False
Expand Down
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
from . import rtmdet_ins_head # noqa: F401,F403
from . import solo_head # noqa: F401,F403
from . import solov2_head # noqa: F401,F403
from . import sparseinst_head # noqa: F401,F403
from . import yolo_head # noqa: F401,F403
from . import yolox_head # noqa: F401,F403
63 changes: 63 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
from mmdet.models.utils import aligned_bilinear
from mmdet.structures import OptSampleList, SampleList
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.core import FUNCTION_REWRITER


@torch.jit.script
def rescoring_mask(scores, mask_pred, masks):
mask_pred_ = mask_pred.float()
return scores * ((masks * mask_pred_).sum([2, 3]) /
(mask_pred_.sum([2, 3]) + 1e-6))


@FUNCTION_REWRITER.register_rewriter(
'projects.SparseInst.sparseinst.SparseInst.predict')
def sparseinst__predict(
self,
batch_inputs: Tensor,
batch_data_samples: List[dict],
rescale: bool = False,
):
"""Rewrite `predict` of `SparseInst` for default backend."""
max_shape = batch_inputs.shape[-2:]
x = self.extract_feat(batch_inputs)
output = self.decoder(x)

pred_scores = output['pred_logits'].sigmoid()
pred_masks = output['pred_masks'].sigmoid()
pred_objectness = output['pred_scores'].sigmoid()
pred_scores = torch.sqrt(pred_scores * pred_objectness)

# max/argmax
scores, labels = pred_scores.max(dim=-1)
# cls threshold
keep = scores > self.cls_threshold
scores = scores.where(keep, scores.new_zeros(1))
labels = labels.where(keep, labels.new_zeros(1))
keep = keep.unsqueeze(-1).unsqueeze(-1).expand_as(pred_masks)
pred_masks = pred_masks.where(keep, pred_masks.new_zeros(1))

img_meta = batch_data_samples[0].metainfo
# rescoring mask using maskness
scores = rescoring_mask(scores,
pred_masks > self.mask_threshold,
pred_masks)
h, w = img_meta['img_shape'][:2]
pred_masks = F.interpolate(pred_masks,
size=max_shape,
mode='bilinear',
align_corners=False)[:, :, :h, :w]

bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4)
dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1)
masks = (pred_masks > self.mask_threshold).float()

return dets, labels, masks

0 comments on commit 2a481bf

Please sign in to comment.