Skip to content

Commit

Permalink
export get_simcc_maximum for simcc (#2449)
Browse files Browse the repository at this point in the history
* update

* update for simcc csrc

* fix docker ci

* update simcc_label
  • Loading branch information
RunningLeon committed Sep 28, 2023
1 parent 1132e82 commit 4cf0f92
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
export TAG=$TAG_PREFIX
echo "TAG=${TAG}" >> $GITHUB_ENV
echo $TAG
docker ./docker/Release/ -t ${TAG} --no-cache
docker build ./docker/Release/ -t ${TAG} --no-cache
docker push $TAG
- name: Push docker image with released tag
if: startsWith(github.ref, 'refs/tags/') == true
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@ jobs:
echo $MMDEPLOY_VERSION
echo "MMDEPLOY_VERSION=$MMDEPLOY_VERSION" >> $GITHUB_ENV
echo "OUTPUT_DIR=$PREBUILD_DIR/$MMDEPLOY_VERSION" >> $GITHUB_ENV
pip install twine
python3 -m pip install twine --user
- name: Upload mmdeploy
continue-on-error: true
run: |
cd $OUTPUT_DIR/mmdeploy
ls -sha *.whl
twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
- name: Upload mmdeploy_runtime
continue-on-error: true
run: |
cd $OUTPUT_DIR/mmdeploy_runtime
ls -sha *.whl
twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
- name: Check assets
run: |
ls -sha $OUTPUT_DIR/sdk
Expand Down
4 changes: 4 additions & 0 deletions configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@
0: 'batch'
}
})

codebase_config = dict(
export_postprocess=False # do not export get_simcc_maximum
)
26 changes: 18 additions & 8 deletions csrc/mmdeploy/codebase/mmpose/simcc_label.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class SimCCLabelDecode : public MMPose {
auto& params = config["params"];
flip_test_ = params.value("flip_test", flip_test_);
simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
export_postprocess_ = params.value("export_postprocess", export_postprocess_);
if (export_postprocess_) {
simcc_split_ratio_ = 1.0;
}
if (params.contains("input_size")) {
from_value(params["input_size"], input_size_);
}
Expand All @@ -52,26 +56,31 @@ class SimCCLabelDecode : public MMPose {

Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
float *keypoints_data = nullptr, *scores_data = nullptr;
if (!export_postprocess_) {
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
keypoints_data = keypoints.data<float>();
scores_data = scores.data<float>();
} else {
keypoints_data = simcc_x.data<float>();
scores_data = simcc_y.data<float>();
}

std::vector<float> center;
std::vector<float> scale;
from_value(img_metas["center"], center);
from_value(img_metas["scale"], scale);
PoseDetectorOutput output;

float* keypoints_data = keypoints.data<float>();
float* scores_data = scores.data<float>();
float scale_value = 200, x = -1, y = -1, s = 0;
for (int i = 0; i < simcc_x.shape(1); i++) {
x = *(keypoints_data + 0) / simcc_split_ratio_;
y = *(keypoints_data + 1) / simcc_split_ratio_;
x = *(keypoints_data++) / simcc_split_ratio_;
y = *(keypoints_data++) / simcc_split_ratio_;
s = *(scores_data++);

x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
s = *(scores_data + 0);
output.key_points.push_back({{x, y}, s});
keypoints_data += 2;
scores_data += 1;
}
return to_value(output);
}
Expand Down Expand Up @@ -104,6 +113,7 @@ class SimCCLabelDecode : public MMPose {

private:
bool flip_test_{false};
bool export_postprocess_{false};
bool shift_heatmap_{false};
float simcc_split_ratio_{2.0};
std::vector<int> input_size_{192, 256};
Expand Down
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmpose/codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.

from .post_processing import get_simcc_maximum

__all__ = ['get_simcc_maximum']
33 changes: 33 additions & 0 deletions mmdeploy/codebase/mmpose/codecs/post_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def get_simcc_maximum(simcc_x: torch.Tensor,
simcc_y: torch.Tensor) -> torch.Tensor:
"""Get maximum response location and value from simcc representations.
rewrite to support `torch.Tensor` input type.
Args:
simcc_x (torch.Tensor): x-axis SimCC in shape (N, K, Wx)
simcc_y (torch.Tensor): y-axis SimCC in shape (N, K, Wy)
Returns:
tuple:
- locs (torch.Tensor): locations of maximum heatmap responses in shape
(N, K, 2)
- vals (torch.Tensor): values of maximum heatmap responses in shape
(N, K)
"""
N, K, _ = simcc_x.shape
simcc_x = simcc_x.flatten(0, 1)
simcc_y = simcc_y.flatten(0, 1)
x_locs = simcc_x.argmax(dim=1, keepdim=True)
y_locs = simcc_y.argmax(dim=1, keepdim=True)
locs = torch.cat((x_locs, y_locs), dim=1).to(torch.float32)
max_val_x, _ = simcc_x.max(dim=1, keepdim=True)
max_val_y, _ = simcc_y.max(dim=1, keepdim=True)
vals, _ = torch.cat([max_val_x, max_val_y], dim=1).min(dim=1)
locs = locs.reshape(N, K, 2)
vals = vals.reshape(N, K)
return locs, vals
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from mmengine.registry import Registry

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
from mmdeploy.utils import (Codebase, Task, get_codebase_config,
get_input_shape, get_root_logger)


def process_model_config(
Expand Down Expand Up @@ -362,6 +363,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
params['post_process'] = 'megvii'
params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
elif codec.type == 'SimCCLabel':
export_postprocess = get_codebase_config(self.deploy_cfg).get(
'export_postprocess', False)
params['export_postprocess'] = export_postprocess
component = 'SimCCLabelDecode'
elif codec.type == 'RegressionLabel':
component = 'DeepposeRegressionHeadDecode'
Expand Down
12 changes: 10 additions & 2 deletions mmdeploy/codebase/mmpose/deploy/pose_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,20 @@ def forward(self,
if self.model_cfg.model.type == 'YOLODetector':
return self.pack_yolox_pose_result(batch_outputs, data_samples)

codebase_cfg = get_codebase_config(self.deploy_cfg)
codec = self.model_cfg.codec
if isinstance(codec, (list, tuple)):
codec = codec[-1]
if codec.type == 'SimCCLabel':
batch_pred_x, batch_pred_y = batch_outputs
preds = self.head.decode((batch_pred_x, batch_pred_y))
export_postprocess = codebase_cfg.get('export_postprocess', False)
if export_postprocess:
keypoints, scores = [_.cpu().numpy() for _ in batch_outputs]
preds = [
InstanceData(keypoints=keypoints, keypoint_scores=scores)
]
else:
batch_pred_x, batch_pred_y = batch_outputs
preds = self.head.decode((batch_pred_x, batch_pred_y))
elif codec.type in ['RegressionLabel', 'IntegralRegressionLabel']:
preds = self.head.decode(batch_outputs)
else:
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmpose/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import mspn_head, yolox_pose_head # noqa: F401,F403
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403

__all__ = ['mspn_head', 'yolox_pose_head']
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
28 changes: 28 additions & 0 deletions mmdeploy/codebase/mmpose/models/heads/simcc_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.codebase.mmpose.codecs import get_simcc_maximum
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import get_codebase_config


@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.RTMCCHead.forward')
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.SimCCHead.forward')
def simcc_head__forward(self, feats):
"""Rewrite `forward` of SimCCHead for default backend.
Args:
feats (tuple[Tensor]): Input features.
Returns:
key-points (torch.Tensor): Output keypoints in
shape of (N, K, 3)
"""
ctx = FUNCTION_REWRITER.get_context()
simcc_x, simcc_y = ctx.origin_func(self, feats)
codebase_cfg = get_codebase_config(ctx.cfg)
export_postprocess = codebase_cfg.get('export_postprocess', False)
if not export_postprocess:
return simcc_x, simcc_y
assert self.decoder.use_dark is False, \
'Do not support SimCCLabel with use_dark=True'
pts, scores = get_simcc_maximum(simcc_x, simcc_y)
pts /= self.decoder.simcc_split_ratio
return pts, scores

0 comments on commit 4cf0f92

Please sign in to comment.