From 2a481bfae9e33aee7631a4d5d3f4758a520e3d8f Mon Sep 17 00:00:00 2001 From: lrh Date: Fri, 10 Nov 2023 11:01:24 +0800 Subject: [PATCH] support sparseinst batch inference --- .../mmdet/deploy/object_detection_model.py | 2 +- .../mmdet/models/dense_heads/__init__.py | 1 + .../models/dense_heads/sparseinst_head.py | 63 +++++++++++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index c6a958e5eb..dab5b074b6 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -241,7 +241,7 @@ def postprocessing_results(self, masks = batch_masks[i] img_h, img_w = img_metas[i]['img_shape'][:2] ori_h, ori_w = img_metas[i]['ori_shape'][:2] - if model_type in ['RTMDet', 'CondInst']: + if model_type in ['RTMDet', 'CondInst', 'SparseInst']: export_postprocess_mask = True else: export_postprocess_mask = False diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 062bc7de52..3bee17f449 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -11,5 +11,6 @@ from . import rtmdet_ins_head # noqa: F401,F403 from . import solo_head # noqa: F401,F403 from . import solov2_head # noqa: F401,F403 +from . import sparseinst_head # noqa: F401,F403 from . import yolo_head # noqa: F401,F403 from . import yolox_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py new file mode 100644 index 0000000000..6ddbb7744c --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch +import torch.nn.functional as F +from mmdet.models.utils import aligned_bilinear +from mmdet.structures import OptSampleList, SampleList +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@torch.jit.script +def rescoring_mask(scores, mask_pred, masks): + mask_pred_ = mask_pred.float() + return scores * ((masks * mask_pred_).sum([2, 3]) / + (mask_pred_.sum([2, 3]) + 1e-6)) + + +@FUNCTION_REWRITER.register_rewriter( + 'projects.SparseInst.sparseinst.SparseInst.predict') +def sparseinst__predict( + self, + batch_inputs: Tensor, + batch_data_samples: List[dict], + rescale: bool = False, +): + """Rewrite `predict` of `SparseInst` for default backend.""" + max_shape = batch_inputs.shape[-2:] + x = self.extract_feat(batch_inputs) + output = self.decoder(x) + + pred_scores = output['pred_logits'].sigmoid() + pred_masks = output['pred_masks'].sigmoid() + pred_objectness = output['pred_scores'].sigmoid() + pred_scores = torch.sqrt(pred_scores * pred_objectness) + + # max/argmax + scores, labels = pred_scores.max(dim=-1) + # cls threshold + keep = scores > self.cls_threshold + scores = scores.where(keep, scores.new_zeros(1)) + labels = labels.where(keep, labels.new_zeros(1)) + keep = keep.unsqueeze(-1).unsqueeze(-1).expand_as(pred_masks) + pred_masks = pred_masks.where(keep, pred_masks.new_zeros(1)) + + img_meta = batch_data_samples[0].metainfo + # rescoring mask using maskness + scores = rescoring_mask(scores, + pred_masks > self.mask_threshold, + pred_masks) + h, w = img_meta['img_shape'][:2] + pred_masks = F.interpolate(pred_masks, + size=max_shape, + mode='bilinear', + align_corners=False)[:, :, :h, :w] + + bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4) + dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1) + masks = (pred_masks > self.mask_threshold).float() + + return dets, labels, masks