diff --git a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp index fb4af47126..6ad142f6fa 100644 --- a/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp +++ b/csrc/mmdeploy/codebase/mmpose/simcc_label.cpp @@ -27,6 +27,9 @@ class SimCCLabelDecode : public MMPose { flip_test_ = params.value("flip_test", flip_test_); simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_); export_postprocess_ = params.value("export_postprocess", export_postprocess_); + if (export_postprocess_) { + simcc_split_ratio_ = 1.0; + } if (params.contains("input_size")) { from_value(params["input_size"], input_size_); } @@ -53,8 +56,14 @@ class SimCCLabelDecode : public MMPose { Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}}); Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}}); + float *keypoints_data = nullptr, *scores_data = nullptr; if (!export_postprocess_) { get_simcc_maximum(simcc_x, simcc_y, keypoints, scores); + keypoints_data = keypoints.data(); + scores_data = scores.data(); + } else { + keypoints_data = simcc_x.data(); + scores_data = simcc_y.data(); } std::vector center; @@ -63,22 +72,11 @@ class SimCCLabelDecode : public MMPose { from_value(img_metas["scale"], scale); PoseDetectorOutput output; - float* keypoints_data = keypoints.data(); - float* simcc_x_data = simcc_x.data(); - float* simcc_y_data = simcc_y.data(); - - float* scores_data = scores.data(); float scale_value = 200, x = -1, y = -1, s = 0; for (int i = 0; i < simcc_x.shape(1); i++) { - if (export_postprocess_) { - x = *(simcc_x_data++); - y = *(simcc_x_data++); - s = *(scores_data++); - } else { - x = *(keypoints_data++) / simcc_split_ratio_; - y = *(keypoints_data++) / simcc_split_ratio_; - s = *(scores_data++); - } + x = *(keypoints_data++) / simcc_split_ratio_; + y = *(keypoints_data++) / simcc_split_ratio_; + s = *(scores_data++); x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5; y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;