Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Sep 8, 2023
1 parent a3c2fae commit 50e9014
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 5 deletions.
4 changes: 4 additions & 0 deletions configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@
0: 'batch'
}
})

codebase_config = dict(
export_postprocess=False # do not export get_simcc_maximum
)
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmpose/codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .post_processing import get_simcc_maximum

__all__ = ['get_simcc_maximum']
33 changes: 33 additions & 0 deletions mmdeploy/codebase/mmpose/codecs/post_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def get_simcc_maximum(simcc_x: torch.Tensor,
simcc_y: torch.Tensor) -> torch.Tensor:
"""Get maximum response location and value from simcc representations.
rewrite to support `torch.Tensor` input type.
Args:
simcc_x (torch.Tensor): x-axis SimCC in shape (N, K, Wx)
simcc_y (torch.Tensor): y-axis SimCC in shape (N, K, Wy)
Returns:
tuple:
- locs (torch.Tensor): locations of maximum heatmap responses in shape
(N, K, 2)
- vals (torch.Tensor): values of maximum heatmap responses in shape
(N, K)
"""
N, K, _ = simcc_x.shape
simcc_x = simcc_x.flatten(0, 1)
simcc_y = simcc_y.flatten(0, 1)
x_locs = simcc_x.argmax(dim=1, keepdim=True)
y_locs = simcc_y.argmax(dim=1, keepdim=True)
locs = torch.cat((x_locs, y_locs), dim=1).to(torch.float32)
max_val_x, _ = simcc_x.max(dim=1, keepdim=True)
max_val_y, _ = simcc_y.max(dim=1, keepdim=True)
vals, _ = torch.cat([max_val_x, max_val_y], dim=1).min(dim=1)
locs = locs.reshape(N, K, 2)
vals = vals.reshape(N, K)
return locs, vals

Check warning on line 33 in mmdeploy/codebase/mmpose/codecs/post_processing.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/codecs/post_processing.py#L22-L33

Added lines #L22 - L33 were not covered by tests
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
from mmdeploy.utils import (Codebase, Task, get_codebase_config,
get_input_shape, get_root_logger)


def process_model_config(
Expand Down Expand Up @@ -362,6 +363,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
params['post_process'] = 'megvii'
params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
elif codec.type == 'SimCCLabel':
export_postprocess = get_codebase_config(self.deploy_cfg).get(

Check warning on line 366 in mmdeploy/codebase/mmpose/deploy/pose_detection.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/deploy/pose_detection.py#L366

Added line #L366 was not covered by tests
'export_postprocess', False)
params['export_postprocess'] = export_postprocess

Check warning on line 368 in mmdeploy/codebase/mmpose/deploy/pose_detection.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/deploy/pose_detection.py#L368

Added line #L368 was not covered by tests
component = 'SimCCLabelDecode'
elif codec.type == 'RegressionLabel':
component = 'DeepposeRegressionHeadDecode'
Expand Down
12 changes: 10 additions & 2 deletions mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,20 @@ def forward(self,
if self.model_cfg.model.type == 'YOLODetector':
return self.pack_yolox_pose_result(batch_outputs, data_samples)

codebase_cfg = get_codebase_config(self.deploy_cfg)
codec = self.model_cfg.codec
if isinstance(codec, (list, tuple)):
codec = codec[-1]
if codec.type == 'SimCCLabel':
batch_pred_x, batch_pred_y = batch_outputs
preds = self.head.decode((batch_pred_x, batch_pred_y))
export_postprocess = codebase_cfg.get('export_postprocess', False)

Check warning on line 109 in mmdeploy/codebase/mmpose/deploy/pose_detection_model.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/deploy/pose_detection_model.py#L109

Added line #L109 was not covered by tests
if export_postprocess:
keypoints, scores = [_.cpu().numpy() for _ in batch_outputs]
preds = [

Check warning on line 112 in mmdeploy/codebase/mmpose/deploy/pose_detection_model.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/deploy/pose_detection_model.py#L112

Added line #L112 was not covered by tests
InstanceData(keypoints=keypoints, keypoint_scores=scores)
]
else:
batch_pred_x, batch_pred_y = batch_outputs
preds = self.head.decode((batch_pred_x, batch_pred_y))

Check warning on line 117 in mmdeploy/codebase/mmpose/deploy/pose_detection_model.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/deploy/pose_detection_model.py#L116-L117

Added lines #L116 - L117 were not covered by tests
elif codec.type in ['RegressionLabel', 'IntegralRegressionLabel']:
preds = self.head.decode(batch_outputs)
else:
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmpose/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mspn_head, yolox_pose_head # noqa: F401,F403
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403

__all__ = ['mspn_head', 'yolox_pose_head']
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
28 changes: 28 additions & 0 deletions mmdeploy/codebase/mmpose/models/heads/simcc_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.codebase.mmpose.codecs import get_simcc_maximum
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import get_codebase_config


@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.RTMCCHead.forward')
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.SimCCHead.forward')
def simcc_head__forward(self, feats):
"""Rewrite `forward` of SimCCHead for default backend.
Args:
feats (tuple[Tensor]): Input features.
Returns:
key-points (torch.Tensor): Output keypoints in
shape of (N, K, 3)
"""
ctx = FUNCTION_REWRITER.get_context()
simcc_x, simcc_y = ctx.origin_func(self, feats)
codebase_cfg = get_codebase_config(ctx.cfg)
export_postprocess = codebase_cfg.get('export_postprocess', False)

Check warning on line 21 in mmdeploy/codebase/mmpose/models/heads/simcc_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/models/heads/simcc_head.py#L18-L21

Added lines #L18 - L21 were not covered by tests
if not export_postprocess:
return simcc_x, simcc_y
assert self.decoder.use_dark is False, \

Check warning on line 24 in mmdeploy/codebase/mmpose/models/heads/simcc_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/models/heads/simcc_head.py#L23-L24

Added lines #L23 - L24 were not covered by tests
'Do not support SimCCLabel with use_dark=True'
pts, scores = get_simcc_maximum(simcc_x, simcc_y)
pts /= self.decoder.simcc_split_ratio
return pts, scores

Check warning on line 28 in mmdeploy/codebase/mmpose/models/heads/simcc_head.py

View check run for this annotation

Codecov / codecov/patch

mmdeploy/codebase/mmpose/models/heads/simcc_head.py#L26-L28

Added lines #L26 - L28 were not covered by tests

0 comments on commit 50e9014

Please sign in to comment.