Skip to content

Commit

Permalink
add yolov8-pose support
Browse files Browse the repository at this point in the history
  • Loading branch information
Neutree committed Jun 13, 2024
1 parent afe293a commit 4158d35
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 5 deletions.
15 changes: 13 additions & 2 deletions components/nn/include/maix_nn_object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace maix::nn
* @maixcdk maix.nn.Object.Object
*/
Object(int x = 0, int y = 0, int w = 0, int h = 0, int class_id = 0, float score = 0, std::vector<int> points = std::vector<int>())
    : x(x), y(y), w(w), h(h), class_id(class_id), score(score), points(points), temp(nullptr)
{
    // Only one mem-initializer list is valid: the stale pre-change list
    // (without `temp`) is dropped so that `temp` always starts out null.
}

Expand Down Expand Up @@ -91,6 +91,12 @@ namespace maix::nn
* @maixpy maix.nn.Object.points
*/
std::vector<int> points;


/**
* For temporary internal use only; not part of the MaixPy API
*/
void *temp;
};

/**
Expand All @@ -112,7 +118,7 @@ namespace maix::nn
* @maixcdk maix.nn.ObjectFloat.ObjectFloat
*/
ObjectFloat(float x = 0, float y = 0, float w = 0, float h = 0, float class_id = 0, float score = 0, std::vector<float> points = std::vector<float>())
    : x(x), y(y), w(w), h(h), class_id(class_id), score(score), points(points), temp(nullptr)
{
    // Only one mem-initializer list is valid: the stale pre-change list
    // (without `temp`) is dropped so that `temp` always starts out null.
}

Expand Down Expand Up @@ -172,6 +178,11 @@ namespace maix::nn
* @maixpy maix.nn.ObjectFloat.points
*/
std::vector<float> points;

/**
* For temporary internal use only; not part of the MaixPy API
*/
void *temp;
};
}

144 changes: 141 additions & 3 deletions components/nn/include/maix_nn_yolov8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@

namespace maix::nn
{
/**
 * Small record attached to a detection while it is being post-processed.
 * It keeps an index plus the anchor position and stride the detection was
 * created with, so keypoints can be decoded for it later.
 */
class _KpInfo
{
public:
    int idx;      // index handed in by the creator
    int anchor_x; // anchor x position
    int anchor_y; // anchor y position
    float stride; // stride associated with the anchor

    /**
     * Store the four values as given; no validation is performed.
     */
    _KpInfo(int index, int cell_x, int cell_y, float cell_stride)
        : idx(index),
          anchor_x(cell_x),
          anchor_y(cell_y),
          stride(cell_stride)
    {
    }
};

/**
* YOLOv8 class
* @maixpy maix.nn.YOLOv8
Expand All @@ -32,6 +45,8 @@ namespace maix::nn
YOLOv8(const string &model = "")
{
_model = nullptr;
_type_pose = false;
_type = "detector";
if (!model.empty())
{
err::Err e = load(model);
Expand Down Expand Up @@ -188,6 +203,23 @@ namespace maix::nn
log::error("labels key not found");
return err::ERR_ARGS;
}
if (_extra_info.find("type") != _extra_info.end())
{
_type = _extra_info["type"];
if (_type == "pose")
{
_type_pose = true;
}
else if (_type != "detector")
{
log::error("type [%s] not support, suport detector and pose", _type.c_str());
return err::ERR_ARGS;
}
}
else
{
_type = "detector";
}
std::vector<nn::LayerInfo> inputs = _model->inputs_info();
_input_size = image::Size(inputs[0].shape[3], inputs[0].shape[2]);
log::print("\tinput size: %dx%d\n\n", _input_size.width(), _input_size.height());
Expand Down Expand Up @@ -223,14 +255,17 @@ namespace maix::nn
* @param conf_th Confidence threshold, default 0.5.
* @param iou_th IoU threshold, default 0.45.
* @param fit Resize method, default image.Fit.FIT_CONTAIN.
* @param keypoint_th keypoint threshold, default 0.5, only for yolov8-pose model.
* @throw If image format not match model input format, will throw err::Exception.
* @return Object list. In C++, you should delete it after use.
* If the model is yolov8-pose, each object's points are filled in; a point whose value is < 0 is invalid (conf < keypoint_th).
* @maixpy maix.nn.YOLOv8.detect
*/
std::vector<nn::Object> *detect(image::Image &img, float conf_th = 0.5, float iou_th = 0.45, maix::image::Fit fit = maix::image::FIT_CONTAIN)
std::vector<nn::Object> *detect(image::Image &img, float conf_th = 0.5, float iou_th = 0.45, maix::image::Fit fit = maix::image::FIT_CONTAIN, float keypoint_th = 0.5)
{
this->_conf_th = conf_th;
this->_iou_th = iou_th;
this->_keypoint_th = keypoint_th;
if (img.format() != _input_img_fmt)
{
throw err::Exception("image format not match, input_type: " + image::fmt_names[_input_img_fmt] + ", image format: " + image::fmt_names[img.format()]);
Expand Down Expand Up @@ -286,6 +321,46 @@ namespace maix::nn
return _input_img_fmt;
}

/**
 * Draw pose keypoints on image
 * @param img image object, maix.image.Image type.
 * @param points keypoints, int list type, [x, y, x, y ...]. A point with a negative x or y is treated as invalid and is not drawn. Skeleton lines are only drawn when at least 17 keypoints (34 values) are given.
 * @param radius radius of points.
 * @param color color of points.
 * @throw std::runtime_error if points has fewer than 2 values or an odd number of values.
 * @maixpy maix.nn.YOLOv8.draw_pose
 */
void draw_pose(image::Image &img, std::vector<int> points, int radius = 4, image::Color color = image::COLOR_RED)
{
    if (points.size() < 2 || points.size() % 2 != 0)
    {
        throw std::runtime_error("keypoints size must >= 2 and multiple of 2");
    }

    const size_t kp_count = points.size() / 2;
    // A keypoint can be drawn only when both of its coordinates are non-negative.
    auto is_valid = [&points](size_t k)
    {
        return points[k * 2] >= 0 && points[k * 2 + 1] >= 0;
    };

    // The skeleton table below references keypoint indices up to 16, so it
    // must not be used on shorter lists (that would read out of bounds).
    const size_t skeleton_kp_num = 17;
    if (kp_count >= skeleton_kp_num)
    {
        // Pairs of keypoint indices joined by a line.
        const int pos[] = {9, 7, 7, 5, 6, 8, 8, 10, 5, 11, 6, 12, 5, 6, 11, 12, 11, 13, 15, 13, 14, 12, 14, 16};
        const size_t pair_num = sizeof(pos) / sizeof(pos[0]) / 2;
        for (size_t i = 0; i < pair_num; ++i)
        {
            const size_t a = static_cast<size_t>(pos[i * 2]);
            const size_t b = static_cast<size_t>(pos[i * 2 + 1]);
            if (!is_valid(a) || !is_valid(b))
                continue;
            img.draw_line(points[a * 2], points[a * 2 + 1], points[b * 2], points[b * 2 + 1], image::COLOR_RED, 2);
        }

        // Line from keypoint 0 to the midpoint of keypoints 5 and 6.
        if (is_valid(0) && is_valid(5) && is_valid(6))
        {
            const int mid_x = (points[5 * 2] + points[6 * 2]) / 2;
            const int mid_y = (points[5 * 2 + 1] + points[6 * 2 + 1]) / 2;
            img.draw_line(points[0], points[1], mid_x, mid_y, image::COLOR_RED, 2);
        }
    }

    // Every valid keypoint is drawn as a filled circle.
    for (size_t i = 0; i < kp_count; ++i)
    {
        if (!is_valid(i))
            continue;
        img.draw_circle(points[i * 2], points[i * 2 + 1], radius, color, -1);
    }
}

public:
/**
* Labels list
Expand Down Expand Up @@ -318,33 +393,46 @@ namespace maix::nn
std::map<string, string> _extra_info;
float _conf_th = 0.5;
float _iou_th = 0.45;
float _keypoint_th = 0.5;
std::string _type;
bool _type_pose;

private:
/**
 * Turn raw model outputs into the final object list.
 * Steps: decode candidate boxes, run NMS, decode keypoints for pose models,
 * then map coordinates from model input space back to the source image.
 * @param outputs model output tensors.
 * @param img_w width of the image the results are mapped back to.
 * @param img_h height of the image the results are mapped back to.
 * @param fit resize method that was used to fit the image to the model input.
 * @return newly allocated object list; the caller owns it and must delete it.
 */
std::vector<nn::Object> *_post_process(tensor::Tensors *outputs, int img_w, int img_h, maix::image::Fit fit)
{
    std::vector<nn::Object> *objects = new std::vector<nn::Object>();
    // Decode exactly once; the returned tensor is the keypoint output (NULL if the model has none).
    tensor::Tensor *kp_out = _decode_objs(*objects, outputs, _conf_th, _input_size.width(), _input_size.height());
    if (objects->size() > 0)
    {
        // _nms returns a new list, so the pre-NMS list is released here.
        std::vector<nn::Object> *objects_total = objects;
        objects = _nms(*objects);
        delete objects_total;
    }
    if (_type_pose && kp_out != nullptr)
    {
        // decode keypoints (this also releases each object's temporary _KpInfo)
        _decode_keypoints(*objects, kp_out);
    }
    else
    {
        // _decode_objs attaches a heap-allocated _KpInfo to every object.
        // When keypoints are not decoded nobody else frees it, so do it here.
        for (nn::Object &obj : *objects)
        {
            delete static_cast<_KpInfo *>(obj.temp);
            obj.temp = nullptr;
        }
    }
    if (objects->size() > 0)
        _correct_bbox(*objects, img_w, img_h, fit);
    return objects;
}

void _decode_objs(std::vector<nn::Object> &objs, tensor::Tensors *outputs, float conf_thresh, int w, int h)
tensor::Tensor *_decode_objs(std::vector<nn::Object> &objs, tensor::Tensors *outputs, float conf_thresh, int w, int h)
{
tensor::Tensor *score_out = NULL; // shape 1, 80, 8400, 1
tensor::Tensor *box_out = NULL; // shape 1, 1, 4, 8400
tensor::Tensor *kp_out = NULL; // shape 1, 51, 8400, 1
for (auto i : *outputs)
{
if (i.second->shape()[2] == 4)
{
box_out = i.second;
}
else if (i.second->shape()[1] == 51) // 17 * 3
{
kp_out = i.second;
}
else
{
score_out = i.second;
Expand Down Expand Up @@ -382,10 +470,13 @@ namespace maix::nn
float bbox_w = (ax + 0.5 + dets_ptr[offset + total_box_num * 2]) * stride[i] - bbox_x;
float bbox_h = (ay + 0.5 + dets_ptr[offset + total_box_num * 3]) * stride[i] - bbox_y;
Object obj(bbox_x, bbox_y, bbox_w, bbox_h, class_id, obj_score);
_KpInfo *kp_info = new _KpInfo(offset, ax, ay, stride[i]);
obj.temp = (void *)kp_info;
objs.push_back(obj);
}
}
}
return kp_out;
}

std::vector<nn::Object> *_nms(std::vector<nn::Object> &objs)
Expand All @@ -411,10 +502,42 @@ namespace maix::nn
{
if (a.score != 0)
result->push_back(a);
else
{
delete (_KpInfo *)a.temp;
a.temp = NULL;
}
}
return result;
}

/**
 * Fill in the keypoints of every object from the keypoint output tensor and
 * release the temporary _KpInfo attached to each object.
 * For each keypoint an (x, y) pair is appended to the object's points; the
 * pair is (-1, -1) when the keypoint score is not above _keypoint_th.
 * @param objs objects whose `temp` points to the _KpInfo created while decoding boxes.
 * @param kp_out keypoint output tensor, expected layout 1, 51, 8400, 1 (17 keypoints * 3 values).
 *               If it is NULL no keypoints are decoded, but `temp` is still released.
 */
void _decode_keypoints(std::vector<nn::Object> &objs, tensor::Tensor *kp_out)
{
    const int kp_num = 17;  // keypoints per object
    const int kp_fields = 3; // x, y, score for each keypoint

    float *data = kp_out ? (float *)kp_out->data() : nullptr;
    const int total_box_num = kp_out ? kp_out->shape()[2] : 0; // 1, 51, 8400, 1
    for (size_t i = 0; i < objs.size(); ++i)
    {
        nn::Object &o = objs.at(i);
        _KpInfo *kp_info = static_cast<_KpInfo *>(o.temp);
        if (data != nullptr && kp_info != nullptr)
        {
            // kp_info->idx is the position of this object's box in the output,
            // values of one field for all boxes are total_box_num apart.
            float *p = data + kp_info->idx;
            o.points.reserve(o.points.size() + kp_num * 2);
            for (int k = 0; k < kp_num; ++k)
            {
                float score = _sigmoid(p[(k * kp_fields + 2) * total_box_num]);
                int x = -1;
                int y = -1;
                if (score > _keypoint_th)
                {
                    x = (p[(k * kp_fields) * total_box_num] * 2.0 + kp_info->anchor_x) * kp_info->stride;
                    y = (p[(k * kp_fields + 1) * total_box_num] * 2.0 + kp_info->anchor_y) * kp_info->stride;
                }
                o.points.push_back(x);
                o.points.push_back(y);
            }
        }
        // The _KpInfo is only needed up to this point; always release it.
        delete kp_info;
        o.temp = nullptr;
    }
}

void _correct_bbox(std::vector<nn::Object> &objs, int img_w, int img_h, maix::image::Fit fit)
{
if (img_w == _input_size.width() && img_h == _input_size.height())
Expand All @@ -429,6 +552,11 @@ namespace maix::nn
obj.y *= scale_y;
obj.w *= scale_x;
obj.h *= scale_y;
for (size_t i = 0; i < obj.points.size() / 2; ++i)
{
obj.points.at(i * 2) *= scale_x;
obj.points.at(i * 2 + 1) *= scale_y;
}
}
}
else if (fit == maix::image::FIT_CONTAIN)
Expand All @@ -445,6 +573,11 @@ namespace maix::nn
obj.y = (obj.y - pad_h) * scale_reverse;
obj.w *= scale_reverse;
obj.h *= scale_reverse;
for (size_t i = 0; i < obj.points.size() / 2; ++i)
{
obj.points.at(i * 2) = (obj.points.at(i * 2) - pad_w) * scale_reverse;
obj.points.at(i * 2 + 1) = (obj.points.at(i * 2 + 1) - pad_h) * scale_reverse;
}
}
}
else if (fit == maix::image::FIT_COVER)
Expand All @@ -461,6 +594,11 @@ namespace maix::nn
obj.y = (obj.y + pad_h) * scale_reverse;
obj.w *= scale_reverse;
obj.h *= scale_reverse;
for (size_t i = 0; i < obj.points.size() / 2; ++i)
{
obj.points.at(i * 2) = (obj.points.at(i * 2) - pad_w) * scale_reverse;
obj.points.at(i * 2 + 1) = (obj.points.at(i * 2 + 1) - pad_h) * scale_reverse;
}
}
}
else
Expand Down
2 changes: 2 additions & 0 deletions components/vision/src/maix_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,8 @@ namespace maix::image
for(size_t i=0; i<keypoints.size() / 2; ++i)
{
cv::Point center(keypoints[i * 2], keypoints[i * 2 + 1]);
if(center.x < 0 || center.y < 0)
continue;
int radius = size;
cv::circle(img, center, radius, cv_color, thickness);
}
Expand Down
2 changes: 2 additions & 0 deletions examples/nn_yolov8/main/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ int _main(int argc, char *argv[])
img->draw_rect(r.x, r.y, r.w, r.h, maix::image::Color::from_rgb(255, 0, 0));
snprintf(tmp_chars, sizeof(tmp_chars), "%s: %.2f", detector.labels[r.class_id].c_str(), r.score);
img->draw_string(r.x, r.y, tmp_chars, maix::image::Color::from_rgb(255, 0, 0));
detector.draw_pose(*img, r.points, 4, image::COLOR_RED);
}
img->save("result.jpg");
delete result;
Expand All @@ -84,6 +85,7 @@ int _main(int argc, char *argv[])
log::info("result: %s", r.to_str().c_str());
img->draw_rect(r.x, r.y, r.w, r.h, maix::image::Color::from_rgb(255, 0, 0));
img->draw_string(r.x, r.y, detector.labels[r.class_id], maix::image::Color::from_rgb(255, 0, 0));
detector.draw_pose(*img, r.points, 4, image::COLOR_RED);
}
disp.show(*img);
delete result;
Expand Down

0 comments on commit 4158d35

Please sign in to comment.