Skip to content

Commit

Permalink
update simcc_label
Browse files Browse the repository at this point in the history
  • Loading branch information
RunningLeon committed Sep 26, 2023
1 parent fa4674a commit 848f786
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions csrc/mmdeploy/codebase/mmpose/simcc_label.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class SimCCLabelDecode : public MMPose {
flip_test_ = params.value("flip_test", flip_test_);
simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
export_postprocess_ = params.value("export_postprocess", export_postprocess_);
if (export_postprocess_) {
simcc_split_ratio_ = 1.0;
}
if (params.contains("input_size")) {
from_value(params["input_size"], input_size_);
}
Expand All @@ -53,8 +56,14 @@ class SimCCLabelDecode : public MMPose {

Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
float *keypoints_data = nullptr, *scores_data = nullptr;
if (!export_postprocess_) {
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
keypoints_data = keypoints.data<float>();
scores_data = scores.data<float>();
} else {
keypoints_data = simcc_x.data<float>();
scores_data = simcc_y.data<float>();
}

std::vector<float> center;
Expand All @@ -63,22 +72,11 @@ class SimCCLabelDecode : public MMPose {
from_value(img_metas["scale"], scale);
PoseDetectorOutput output;

float* keypoints_data = keypoints.data<float>();
float* simcc_x_data = simcc_x.data<float>();
float* simcc_y_data = simcc_y.data<float>();

float* scores_data = scores.data<float>();
float scale_value = 200, x = -1, y = -1, s = 0;
for (int i = 0; i < simcc_x.shape(1); i++) {
if (export_postprocess_) {
x = *(simcc_x_data++);
y = *(simcc_x_data++);
s = *(scores_data++);
} else {
x = *(keypoints_data++) / simcc_split_ratio_;
y = *(keypoints_data++) / simcc_split_ratio_;
s = *(scores_data++);
}
x = *(keypoints_data++) / simcc_split_ratio_;
y = *(keypoints_data++) / simcc_split_ratio_;
s = *(scores_data++);

x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
Expand Down

0 comments on commit 848f786

Please sign in to comment.